diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 39b227a48ddea..95607f297c6bd 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -20,7 +20,8 @@ jobs: # Override exempt-all-assignees but only to exempt the issues with an assignee to be marked as stale automatically exempt-all-issue-assignees: true # Used to ignore the issues and pull requests created before the start date - start-date: 20220419 + # Start date should be April 19, 2022 - corresponds to the day previous stale bot stopped working + start-date: '2022-04-19T00:00:00Z' # Number of days without activity before the actions/stale action labels an issue days-before-issue-stale: 30 # Number of days without activity before the actions/stale action closes an issue diff --git a/.vscode/settings.json b/.vscode/settings.json index fd28e2d7b335c..c4a08e3232a82 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -16,25 +16,6 @@ }, // Enable Python linting and Pylance type checking "python.analysis.typeCheckingMode": "basic", - "python.formatting.provider": "black", - "python.formatting.blackArgs": [ - "--line-length", - "120" - ], - "python.sortImports.args": [ - "--profile", - "black", - "--line-length", - "120" - ], - "python.linting.enabled": true, - "python.linting.flake8Enabled": true, - "python.linting.pylintEnabled": true, - "python.linting.pydocstyleEnabled": true, - "python.linting.pydocstyleArgs": [ - "--convention=google" - ], - "python.linting.banditEnabled": true, "cpplint.lineLength": 120, "cpplint.filters": [ "-build/include_subdir", diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 0886a29fa573e..12fbb291c3a70 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -286,7 +286,7 @@ "component": { "type": "git", "git": { - "commitHash": "c4f6b8c6bc94ff69048492fb34df0dfaf1983933", + "commitHash": "6f47420213f757831fae65c686aa471749fa8d60", "repositoryUrl": "https://github.com/NVIDIA/cutlass.git" }, "comments": "cutlass" @@ -316,7 +316,7 @@ "component": { "type": "git", "git": { - "commitHash": "d52ec01652b7d620386251db92455968d8d90bdc", + "commitHash": "a4f72a314a85732ed67d5aa8d1088d207a7e0e61", "repositoryUrl": "https://github.com/ROCmSoftwarePlatform/composable_kernel.git" }, "comments": "composable_kernel" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 94181448fd21c..5796db03fed7c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -114,9 +114,7 @@ option(onnxruntime_ENABLE_LTO "Enable link time optimization" OFF) option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF) option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF) option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) - -#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But Tensort always required the full version of protobuf. -cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON) +option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir") option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) @@ -526,7 +524,21 @@ if(NOT WIN32 AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android") find_package(Iconv REQUIRED) set(ICONV_LIB Iconv::Iconv) endif() + find_package(Patch) +if (WIN32 AND NOT Patch_FOUND) + # work around CI machines missing patch from the git install by falling back to the binary in this repo. + # replicate what happens in https://github.com/Kitware/CMake/blob/master/Modules/FindPatch.cmake but without + # the hardcoded suffixes in the path to the patch binary. + find_program(Patch_EXECUTABLE NAMES patch PATHS ${PROJECT_SOURCE_DIR}/external/git.Win32.2.41.03.patch) + if(Patch_EXECUTABLE) + set(Patch_FOUND 1) + if (NOT TARGET Patch::patch) + add_executable(Patch::patch IMPORTED) + set_property(TARGET Patch::patch PROPERTY IMPORTED_LOCATION ${Patch_EXECUTABLE}) + endif() + endif() +endif() if(Patch_FOUND) message("Patch found: ${Patch_EXECUTABLE}") endif() diff --git a/cmake/deps.txt b/cmake/deps.txt index 275b5eaf6b976..49142372ab86e 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -51,7 +51,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8b re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 -cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd +cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/d52ec01652b7d620386251db92455968d8d90bdc.zip;6b5ce8edf3625f8817086c194fbf94b664e1b0e0 +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/a4f72a314a85732ed67d5aa8d1088d207a7e0e61.zip;f57357ab6d300e207a632d034ebc8aa036a090d9 diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index 708d6ba18750b..1e5a36fb9efb9 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -30,7 +30,6 @@ - empty size={ _size() } size=({_size()}) diff --git a/cmake/external/composable_kernel.cmake b/cmake/external/composable_kernel.cmake index 7168cd1a22c53..b4e6c834c83ab 100644 --- a/cmake/external/composable_kernel.cmake +++ b/cmake/external/composable_kernel.cmake @@ -12,13 +12,14 @@ if(NOT composable_kernel_POPULATED) FetchContent_Populate(composable_kernel) set(BUILD_DEV OFF CACHE BOOL "Disable -Weverything, otherwise, error: 'constexpr' specifier is incompatible with C++98 [-Werror,-Wc++98-compat]" FORCE) # Exclude i8 device gemm instances due to excessive long compilation time and not being used - set(DTYPES fp32 fp16 bf16) + set(DTYPES fp32 fp16 bf16 fp8) set(INSTANCES_ONLY ON) add_subdirectory(${composable_kernel_SOURCE_DIR} ${composable_kernel_BINARY_DIR} EXCLUDE_FROM_ALL) add_library(onnxruntime_composable_kernel_includes INTERFACE) target_include_directories(onnxruntime_composable_kernel_includes INTERFACE ${composable_kernel_SOURCE_DIR}/include + ${composable_kernel_BINARY_DIR}/include ${composable_kernel_SOURCE_DIR}/library/include) target_compile_definitions(onnxruntime_composable_kernel_includes INTERFACE __fp32__ __fp16__ __bf16__) endif() diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index 8c5d81d638ced..983eecdd88235 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -4,7 +4,6 @@ if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTIO cutlass URL ${DEP_URL_cutlass} URL_HASH SHA1=${DEP_SHA1_cutlass} - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass.patch ) FetchContent_GetProperties(cutlass) diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index c0f7ddc50eb98..b123adb624fa4 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -1,23 +1,14 @@ - if (onnxruntime_USE_PREINSTALLED_EIGEN) add_library(eigen INTERFACE) file(TO_CMAKE_PATH ${eigen_SOURCE_PATH} eigen_INCLUDE_DIRS) target_include_directories(eigen INTERFACE ${eigen_INCLUDE_DIRS}) else () - if (onnxruntime_USE_ACL) - FetchContent_Declare( - eigen - URL ${DEP_URL_eigen} - URL_HASH SHA1=${DEP_SHA1_eigen} - PATCH_COMMAND ${Patch_EXECUTABLE} --ignore-space-change --ignore-whitespace < ${PROJECT_SOURCE_DIR}/patches/eigen/Fix_Eigen_Build_Break.patch - ) - else() - FetchContent_Declare( - eigen - URL ${DEP_URL_eigen} - URL_HASH SHA1=${DEP_SHA1_eigen} - ) - endif() + FetchContent_Declare( + eigen + URL ${DEP_URL_eigen} + URL_HASH SHA1=${DEP_SHA1_eigen} + ) + FetchContent_Populate(eigen) set(eigen_INCLUDE_DIRS "${eigen_SOURCE_DIR}") endif() diff --git a/cmake/external/git.Win32.2.41.03.patch/msys-2.0.dll b/cmake/external/git.Win32.2.41.03.patch/msys-2.0.dll new file mode 100644 index 0000000000000..686afedb50bf3 Binary files /dev/null and b/cmake/external/git.Win32.2.41.03.patch/msys-2.0.dll differ diff --git a/cmake/external/git.Win32.2.41.03.patch/msys-gcc_s-1.dll b/cmake/external/git.Win32.2.41.03.patch/msys-gcc_s-1.dll new file mode 100644 index 0000000000000..1750b9ce92d72 Binary files /dev/null and b/cmake/external/git.Win32.2.41.03.patch/msys-gcc_s-1.dll differ diff --git a/cmake/external/git.Win32.2.41.03.patch/patch.exe b/cmake/external/git.Win32.2.41.03.patch/patch.exe new file mode 100644 index 0000000000000..8d784cb5d7e40 Binary files /dev/null and b/cmake/external/git.Win32.2.41.03.patch/patch.exe differ diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 019c6341d2e46..0fa5163dc06bf 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -335,6 +335,7 @@ if(onnxruntime_USE_CUDA) URL ${DEP_URL_microsoft_gsl} URL_HASH SHA1=${DEP_SHA1_microsoft_gsl} PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/gsl/1064.patch + FIND_PACKAGE_ARGS 4.0 NAMES Microsoft.GSL ) else() FetchContent_Declare( diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index a62b1b259d109..04efa5c2b4f6d 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -33,6 +33,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/sqnbitgemm.cpp ) if (NOT onnxruntime_ORT_MINIMAL_BUILD) @@ -68,6 +69,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) set(mlas_platform_preprocess_srcs @@ -334,6 +336,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) if (NOT APPLE) set(mlas_platform_srcs diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index a9a78668b4810..345ef2b504aa4 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -339,9 +339,6 @@ configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py) if (onnxruntime_ENABLE_TRAINING) - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/python/deprecated/*.py" - ) file(GLOB onnxruntime_python_root_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/*.py" ) @@ -419,10 +416,6 @@ if (onnxruntime_ENABLE_TRAINING) "${ORTTRAINING_SOURCE_DIR}/python/training/onnxblock/optim/*" ) endif() -else() - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/python/training/*.py" - ) endif() if (onnxruntime_BUILD_UNIT_TESTS) @@ -443,6 +436,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx" ) + file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx" + ) endif() file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS @@ -556,6 +552,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/whisper COMMAND ${CMAKE_COMMAND} -E make_directory $/eager_test + COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/conformer COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_ROOT}/__init__.py $/onnxruntime/ @@ -577,9 +574,6 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py $/onnxruntime/capi/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy $ $/onnxruntime/capi/ @@ -711,6 +705,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_testdata_whisper} $/transformers/test_data/models/whisper/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_testdata_conformer} + $/transformers/test_data/models/conformer/ ) endif() @@ -750,9 +747,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/data/ COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/hooks/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_root_srcs} $/onnxruntime/training/ diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index ce27bf756f247..980bd59b22c3f 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -51,6 +51,7 @@ set(contrib_ops_excluded_files "math/gemm_float8.cc" "math/gemm_float8.cu" "math/gemm_float8.h" + "moe/*" "quantization/attention_quantization.cc" "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 12b439f7de7bc..98e237114dfbe 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -906,7 +906,7 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" -s INCOMING_MODULE_JS_API=[preRun,locateFile] --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") + set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") endif() diff --git a/cmake/patches/composable_kernel/Fix_Clang_Build.patch b/cmake/patches/composable_kernel/Fix_Clang_Build.patch index d564ffba914fe..02b30af9eef52 100644 --- a/cmake/patches/composable_kernel/Fix_Clang_Build.patch +++ b/cmake/patches/composable_kernel/Fix_Clang_Build.patch @@ -1,17 +1,17 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 514b98fde..59c8a568a 100644 +index b09da41a8..fca2bdf69 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -1,7 +1,7 @@ - cmake_minimum_required(VERSION 3.14) +@@ -19,7 +19,7 @@ endif() + set(version 1.1.0) # Check support for CUDA/HIP in Cmake --project(composable_kernel) -+project(composable_kernel LANGUAGES CXX HIP) +-project(composable_kernel VERSION ${version}) ++project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") -@@ -94,27 +94,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) +@@ -173,27 +173,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") @@ -39,7 +39,7 @@ index 514b98fde..59c8a568a 100644 ## HIP find_package(HIP REQUIRED) # Override HIP version in config.h, if necessary. -@@ -136,8 +115,6 @@ if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) +@@ -215,8 +194,6 @@ if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) message(STATUS "CK_HIP_VERSION_PATCH overriden with ${CK_OVERRIDE_HIP_VERSION_PATCH}") endif() message(STATUS "Build with HIP ${HIP_VERSION}") @@ -48,7 +48,7 @@ index 514b98fde..59c8a568a 100644 ## tidy include(EnableCompilerWarnings) -@@ -391,11 +368,3 @@ rocm_install(FILES +@@ -489,11 +466,3 @@ rocm_install(FILES set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") @@ -61,20 +61,21 @@ index 514b98fde..59c8a568a 100644 - HEADER_ONLY -) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -index 1d54a141b..4edd7dbfb 100644 +index a0478c9f0..1e7782cd4 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt -@@ -1,7 +1,13 @@ - function(add_instance_library INSTANCE_NAME) - message("adding instance ${INSTANCE_NAME}") -+ set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) - add_library(${INSTANCE_NAME} OBJECT ${ARGN}) -+ # Always disable debug symbol and C debug assert due to -+ # - Linker error: ... relocation truncated to fit ..., caused by object files to be linked are too huge. -+ # - https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/622 -+ target_compile_options(${INSTANCE_NAME} PRIVATE -g0 -DNDEBUG) - target_compile_features(${INSTANCE_NAME} PUBLIC) -+ target_compile_definitions(${INSTANCE_NAME} PRIVATE "__HIP_PLATFORM_AMD__=1" "__HIP_PLATFORM_HCC__=1") - set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) - clang_tidy_check(${INSTANCE_NAME}) - endfunction(add_instance_library INSTANCE_NAME) +@@ -44,8 +44,14 @@ function(add_instance_library INSTANCE_NAME) + endforeach() + #only continue if there are some source files left on the list + if(ARGN) ++ set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) + add_library(${INSTANCE_NAME} OBJECT ${ARGN}) ++ # Always disable debug symbol and C debug assert due to ++ # - Linker error: ... relocation truncated to fit ..., caused by object files to be linked are too huge. ++ # - https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/622 ++ target_compile_options(${INSTANCE_NAME} PRIVATE -g0 -DNDEBUG) + target_compile_features(${INSTANCE_NAME} PUBLIC) ++ target_compile_definitions(${INSTANCE_NAME} PRIVATE "__HIP_PLATFORM_AMD__=1" "__HIP_PLATFORM_HCC__=1") + set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) + clang_tidy_check(${INSTANCE_NAME}) + set(result 0) diff --git a/cmake/patches/cutlass/cutlass.patch b/cmake/patches/cutlass/cutlass.patch deleted file mode 100644 index bda1de8b46916..0000000000000 --- a/cmake/patches/cutlass/cutlass.patch +++ /dev/null @@ -1,92 +0,0 @@ -diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp -index 3790ebd3..cf727d09 100644 ---- a/include/cute/numeric/complex.hpp -+++ b/include/cute/numeric/complex.hpp -@@ -41,10 +41,14 @@ - // With CUDA 11.4, builds show spurious "-Wconversion" warnings - // on line 656 of thrust/detail/type_traits.h. - // These pragmas suppress the warnings. -+#ifdef __GNUC__ - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wconversion" -+#endif - #include -+#ifdef __GNUC__ - #pragma GCC diagnostic pop -+#endif - - #include - -diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h -index 59aec46a..8f2a913a 100644 ---- a/include/cutlass/functional.h -+++ b/include/cutlass/functional.h -@@ -89,7 +89,7 @@ struct multiplies { - } - }; - --#if defined(__CUDA_ARCH__) -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - /// Partial specializations needed when __CUDA_NO_HALF2_OPERATORS__ is set - template<> - struct plus<__half2> { -@@ -143,12 +143,12 @@ struct multiplies<__half> { - - - // Maximum with nan propogation --// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN -+// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN - template - struct maximum_with_nan_propogation { - CUTLASS_HOST_DEVICE - T operator()(T const &lhs, T const &rhs) const { -- return lhs > rhs or std::isnan(lhs) ? lhs : rhs; -+ return lhs > rhs or isnan(lhs) ? lhs : rhs; - } - }; - -@@ -160,7 +160,7 @@ struct maximum_with_nan_propogation { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); - #else -- res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; -+ res = lhs > rhs or isnan(lhs) ? lhs : rhs; - #endif - return res; - } -@@ -233,7 +233,7 @@ struct negate { - } - }; - --/// Greater equal -+/// Greater equal - template - struct greater_equal { - CUTLASS_HOST_DEVICE -@@ -242,7 +242,7 @@ struct greater_equal { - } - }; - --/// Greater -+/// Greater - template - struct greater { - CUTLASS_HOST_DEVICE -@@ -251,7 +251,7 @@ struct greater { - } - }; - --/// Less equal -+/// Less equal - template - struct less_equal { - CUTLASS_HOST_DEVICE -@@ -260,7 +260,7 @@ struct less_equal { - } - }; - --/// Less -+/// Less - template - struct less { - CUTLASS_HOST_DEVICE diff --git a/cmake/patches/eigen/Fix_Eigen_Build_Break.patch b/cmake/patches/eigen/Fix_Eigen_Build_Break.patch deleted file mode 100644 index ca0e0fd23ddee..0000000000000 --- a/cmake/patches/eigen/Fix_Eigen_Build_Break.patch +++ /dev/null @@ -1,27 +0,0 @@ -diff -Naur git_org/cmake/external/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h git/cmake/external/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h ---- git_org/cmake/external/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h 2019-07-17 15:27:59.540667336 -0500 -+++ git/cmake/external/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h 2019-07-17 15:30:16.000000000 -0500 -@@ -1076,8 +1076,9 @@ - dest = *b; - } - -- EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketx4& dest) const -- {} -+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const -+ { -+ } - - EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const - { -@@ -1145,8 +1146,9 @@ - loadRhs(b,dest); - } - -- EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacketx4& dest) const -- {} -+ EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const -+ { -+ } - - EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const - { diff --git a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch index 460b4d97c499b..736fffb1e384c 100644 --- a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch +++ b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index dba9b4687..bcaa18ad7 100755 +index dba9b4687..a4345898d 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -122,7 +122,7 @@ ENDIF() @@ -25,3 +25,15 @@ index dba9b4687..bcaa18ad7 100755 SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES) ENDIF() IF(NOT MSVC) +@@ -543,8 +548,9 @@ ENDIF() + IF(XNNPACK_TARGET_PROCESSOR STREQUAL "arm") + SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") + SET_PROPERTY(SOURCE ${PROD_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") +- SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv6 -mfpu=vfp -munaligned-access ") +- SET_PROPERTY(SOURCE ${PROD_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv6 -mfpu=vfp -munaligned-access ") ++ # set this to armv7-a to workaround build issue. we don't target armv6 so it shouldn't matter ++ SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ") ++ SET_PROPERTY(SOURCE ${PROD_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ") + SET_PROPERTY(SOURCE ${ALL_NEON_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon ") + SET_PROPERTY(SOURCE ${PROD_NEON_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon ") + SET_PROPERTY(SOURCE ${ALL_NEONFP16_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon-fp16 ") diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index 69bfd9896f1e4..5e43756ced7b1 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -92,6 +92,13 @@ CMake creates a target to this project + + + + + - @@ -109,9 +116,9 @@ CMake creates a target to this project @@ -119,11 +126,11 @@ CMake creates a target to this project diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index f722ca9d30fa4..4128524b30483 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -373,7 +373,7 @@ static NativeMethods() OrtAddSessionConfigEntry = (DOrtAddSessionConfigEntry)Marshal.GetDelegateForFunctionPointer(api_.AddSessionConfigEntry, typeof(DOrtAddSessionConfigEntry)); OrtAddInitializer = (DOrtAddInitializer)Marshal.GetDelegateForFunctionPointer(api_.AddInitializer, typeof(DOrtAddInitializer)); SessionOptionsAppendExecutionProvider_TensorRT = (DSessionOptionsAppendExecutionProvider_TensorRT)Marshal.GetDelegateForFunctionPointer( - api_.SessionOptionsAppendExecutionProvider_TensorRT, typeof(DSessionOptionsAppendExecutionProvider_TensorRT)); + api_.SessionOptionsAppendExecutionProvider_TensorRT, typeof(DSessionOptionsAppendExecutionProvider_TensorRT)); OrtCreateRunOptions = (DOrtCreateRunOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateRunOptions, typeof(DOrtCreateRunOptions)); OrtReleaseRunOptions = (DOrtReleaseRunOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseRunOptions, typeof(DOrtReleaseRunOptions)); @@ -487,27 +487,26 @@ static NativeMethods() OrtReleasePrepackedWeightsContainer = (DOrtReleasePrepackedWeightsContainer)Marshal.GetDelegateForFunctionPointer(api_.ReleasePrepackedWeightsContainer, typeof(DOrtReleasePrepackedWeightsContainer)); SessionOptionsAppendExecutionProvider_TensorRT_V2 = (DSessionOptionsAppendExecutionProvider_TensorRT_V2)Marshal.GetDelegateForFunctionPointer( - api_.SessionOptionsAppendExecutionProvider_TensorRT_V2, typeof(DSessionOptionsAppendExecutionProvider_TensorRT_V2)); + api_.SessionOptionsAppendExecutionProvider_TensorRT_V2, typeof(DSessionOptionsAppendExecutionProvider_TensorRT_V2)); OrtCreateTensorRTProviderOptions = (DOrtCreateTensorRTProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateTensorRTProviderOptions, typeof(DOrtCreateTensorRTProviderOptions)); OrtUpdateTensorRTProviderOptions = (DOrtUpdateTensorRTProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateTensorRTProviderOptions, typeof(DOrtUpdateTensorRTProviderOptions)); OrtGetTensorRTProviderOptionsAsString = (DOrtGetTensorRTProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetTensorRTProviderOptionsAsString, typeof(DOrtGetTensorRTProviderOptionsAsString)); OrtReleaseTensorRTProviderOptions = (DOrtReleaseTensorRTProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseTensorRTProviderOptions, typeof(DOrtReleaseTensorRTProviderOptions)); SessionOptionsAppendExecutionProvider_CUDA = (DSessionOptionsAppendExecutionProvider_CUDA)Marshal.GetDelegateForFunctionPointer( - api_.SessionOptionsAppendExecutionProvider_CUDA, typeof(DSessionOptionsAppendExecutionProvider_CUDA)); + api_.SessionOptionsAppendExecutionProvider_CUDA, typeof(DSessionOptionsAppendExecutionProvider_CUDA)); SessionOptionsAppendExecutionProvider_CUDA_V2 = (DSessionOptionsAppendExecutionProvider_CUDA_V2)Marshal.GetDelegateForFunctionPointer( - api_.SessionOptionsAppendExecutionProvider_CUDA_V2, typeof(DSessionOptionsAppendExecutionProvider_CUDA_V2)); + api_.SessionOptionsAppendExecutionProvider_CUDA_V2, typeof(DSessionOptionsAppendExecutionProvider_CUDA_V2)); OrtCreateCUDAProviderOptions = (DOrtCreateCUDAProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateCUDAProviderOptions, typeof(DOrtCreateCUDAProviderOptions)); OrtUpdateCUDAProviderOptions = (DOrtUpdateCUDAProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateCUDAProviderOptions, typeof(DOrtUpdateCUDAProviderOptions)); OrtGetCUDAProviderOptionsAsString = (DOrtGetCUDAProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetCUDAProviderOptionsAsString, typeof(DOrtGetCUDAProviderOptionsAsString)); OrtReleaseCUDAProviderOptions = (DOrtReleaseCUDAProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseCUDAProviderOptions, typeof(DOrtReleaseCUDAProviderOptions)); - SessionOptionsAppendExecutionProvider - = (DSessionOptionsAppendExecutionProvider)Marshal.GetDelegateForFunctionPointer( - api_.SessionOptionsAppendExecutionProvider, - typeof(DSessionOptionsAppendExecutionProvider)); + SessionOptionsAppendExecutionProvider = (DSessionOptionsAppendExecutionProvider)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsAppendExecutionProvider, + typeof(DSessionOptionsAppendExecutionProvider)); OrtUpdateEnvWithCustomLogLevel = (DOrtUpdateEnvWithCustomLogLevel)Marshal.GetDelegateForFunctionPointer(api_.UpdateEnvWithCustomLogLevel, typeof(DOrtUpdateEnvWithCustomLogLevel)); SessionOptionsAppendExecutionProvider_ROCM = (DSessionOptionsAppendExecutionProvider_ROCM)Marshal.GetDelegateForFunctionPointer( - api_.SessionOptionsAppendExecutionProvider_ROCM, typeof(DSessionOptionsAppendExecutionProvider_ROCM)); + api_.SessionOptionsAppendExecutionProvider_ROCM, typeof(DSessionOptionsAppendExecutionProvider_ROCM)); OrtCreateROCMProviderOptions = (DOrtCreateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateROCMProviderOptions, typeof(DOrtCreateROCMProviderOptions)); OrtUpdateROCMProviderOptions = (DOrtUpdateROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateROCMProviderOptions, typeof(DOrtUpdateROCMProviderOptions)); OrtGetROCMProviderOptionsAsString = (DOrtGetROCMProviderOptionsAsString)Marshal.GetDelegateForFunctionPointer(api_.GetROCMProviderOptionsAsString, typeof(DOrtGetROCMProviderOptionsAsString)); @@ -532,10 +531,10 @@ internal class NativeLib [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] public static extern ref OrtApiBase OrtGetApiBase(); - #region Runtime/Environment API +#region Runtime / Environment API [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateEnv( + public delegate IntPtr /* OrtStatus* */ DOrtCreateEnv( OrtLoggingLevel defaultLoggingLevel, byte[] /*const char* */ logId, out IntPtr /*(OrtEnv*)*/ env); @@ -543,7 +542,7 @@ internal class NativeLib public static DOrtCreateEnv OrtCreateEnv; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateEnvWithCustomLogger( + public delegate IntPtr /* OrtStatus* */ DOrtCreateEnvWithCustomLogger( IntPtr /* (OrtLoggingFunction*) */ loggingFunction, IntPtr /* (void*) */ loggerParam, OrtLoggingLevel defaultLoggingLevel, @@ -553,7 +552,7 @@ internal class NativeLib public static DOrtCreateEnvWithCustomLogger OrtCreateEnvWithCustomLogger; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateEnvWithGlobalThreadPools( + public delegate IntPtr /* OrtStatus* */ DOrtCreateEnvWithGlobalThreadPools( OrtLoggingLevel defaultWarningLevel, byte[] /*const char* */ logId, IntPtr /*(const OrtThreadingOptions *) */ threadingOptions, @@ -564,7 +563,7 @@ internal class NativeLib [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtCreateEnvWithCustomLoggerAndGlobalThreadPools( IntPtr /* OrtLoggingFunction */ loggingFunction, - IntPtr /* void* */loggerParam, + IntPtr /* void* */ loggerParam, OrtLoggingLevel logSeverityLevel, byte[] /* const char* */ logId, IntPtr /*(const OrtThreadingOptions *) */ threadingOptions, @@ -578,27 +577,27 @@ internal class NativeLib public static DOrtReleaseEnv OrtReleaseEnv; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtEnableTelemetryEvents(IntPtr /*(OrtEnv*)*/ env); + public delegate IntPtr /* OrtStatus* */ DOrtEnableTelemetryEvents(IntPtr /*(OrtEnv*)*/ env); public static DOrtEnableTelemetryEvents OrtEnableTelemetryEvents; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtDisableTelemetryEvents(IntPtr /*(OrtEnv*)*/ env); + public delegate IntPtr /* OrtStatus* */ DOrtDisableTelemetryEvents(IntPtr /*(OrtEnv*)*/ env); public static DOrtDisableTelemetryEvents OrtDisableTelemetryEvents; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtUpdateEnvWithCustomLogLevel(IntPtr /*(OrtEnv*)*/ env, OrtLoggingLevel custom_log_level); + public delegate IntPtr /* OrtStatus* */ DOrtUpdateEnvWithCustomLogLevel(IntPtr /*(OrtEnv*)*/ env, OrtLoggingLevel custom_log_level); public static DOrtUpdateEnvWithCustomLogLevel OrtUpdateEnvWithCustomLogLevel; - #endregion Runtime/Environment API +#endregion Runtime / Environment API - #region Provider Options API +#region Provider Options API /// /// Creates native OrtTensorRTProviderOptions instance /// /// (output) native instance of OrtTensorRTProviderOptions [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateTensorRTProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtCreateTensorRTProviderOptions( out IntPtr /*(OrtTensorRTProviderOptions**)*/ trtProviderOptionsInstance); public static DOrtCreateTensorRTProviderOptions OrtCreateTensorRTProviderOptions; @@ -610,7 +609,7 @@ internal class NativeLib /// configuration values of OrtTensorRTProviderOptions /// number of configuration keys [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtUpdateTensorRTProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtUpdateTensorRTProviderOptions( IntPtr /*(OrtTensorRTProviderOptions*)*/ trtProviderOptionsInstance, IntPtr[] /*(const char* const *)*/ providerOptionsKeys, IntPtr[] /*(const char* const *)*/ providerOptionsValues, @@ -623,10 +622,10 @@ internal class NativeLib /// instance of OrtAllocator /// is a UTF-8 null terminated string allocated using 'allocator' [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtGetTensorRTProviderOptionsAsString( + public delegate IntPtr /* OrtStatus* */ DOrtGetTensorRTProviderOptionsAsString( IntPtr /*(OrtTensorRTProviderOptionsV2**)*/ trtProviderOptionsInstance, IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/ptr); + out IntPtr /*(char**)*/ ptr); public static DOrtGetTensorRTProviderOptionsAsString OrtGetTensorRTProviderOptionsAsString; /// @@ -642,7 +641,7 @@ internal class NativeLib /// /// (output) native instance of OrtCUDAProviderOptions [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateCUDAProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtCreateCUDAProviderOptions( out IntPtr /*(OrtCUDAProviderOptions**)*/ cudaProviderOptionsInstance); public static DOrtCreateCUDAProviderOptions OrtCreateCUDAProviderOptions; @@ -654,7 +653,7 @@ internal class NativeLib /// configuration values of OrtCUDAProviderOptions /// number of configuration keys [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtUpdateCUDAProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtUpdateCUDAProviderOptions( IntPtr /*(OrtCUDAProviderOptions*)*/ cudaProviderOptionsInstance, IntPtr[] /*(const char* const *)*/ providerOptionsKeys, IntPtr[] /*(const char* const *)*/ providerOptionsValues, @@ -667,10 +666,10 @@ internal class NativeLib /// instance of OrtAllocator /// is a UTF-8 null terminated string allocated using 'allocator' [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtGetCUDAProviderOptionsAsString( + public delegate IntPtr /* OrtStatus* */ DOrtGetCUDAProviderOptionsAsString( IntPtr /*(OrtCUDAProviderOptionsV2**)*/ cudaProviderOptionsInstance, IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/ptr); + out IntPtr /*(char**)*/ ptr); public static DOrtGetCUDAProviderOptionsAsString OrtGetCUDAProviderOptionsAsString; /// @@ -686,7 +685,7 @@ internal class NativeLib /// /// (output) native instance of OrtROCMProviderOptions [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateROCMProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtCreateROCMProviderOptions( out IntPtr /*(OrtROCMProviderOptions**)*/ rocmProviderOptionsInstance); public static DOrtCreateROCMProviderOptions OrtCreateROCMProviderOptions; @@ -698,7 +697,7 @@ internal class NativeLib /// configuration values of OrtROCMProviderOptions /// number of configuration keys [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtUpdateROCMProviderOptions( + public delegate IntPtr /* OrtStatus* */ DOrtUpdateROCMProviderOptions( IntPtr /*(OrtROCMProviderOptions*)*/ rocmProviderOptionsInstance, IntPtr[] /*(const char* const *)*/ providerOptionsKeys, IntPtr[] /*(const char* const *)*/ providerOptionsValues, @@ -711,10 +710,10 @@ internal class NativeLib /// instance of OrtAllocator /// is a UTF-8 null terminated string allocated using 'allocator' [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtGetROCMProviderOptionsAsString( + public delegate IntPtr /* OrtStatus* */ DOrtGetROCMProviderOptionsAsString( IntPtr /*(OrtROCMProviderOptions**)*/ rocmProviderOptionsInstance, IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/ptr); + out IntPtr /*(char**)*/ ptr); public static DOrtGetROCMProviderOptionsAsString OrtGetROCMProviderOptionsAsString; /// @@ -725,34 +724,34 @@ internal class NativeLib public delegate void DOrtReleaseROCMProviderOptions(IntPtr /*(OrtROCMProviderOptions*)*/ rocmProviderOptionsInstance); public static DOrtReleaseROCMProviderOptions OrtReleaseROCMProviderOptions; - #endregion +#endregion - #region Status API +#region Status API [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/status); + public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/ status); public static DOrtGetErrorCode OrtGetErrorCode; // returns char*, need to convert to string by the caller. // does not free the underlying OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* char* */DOrtGetErrorMessage(IntPtr /* (OrtStatus*) */status); + public delegate IntPtr /* char* */ DOrtGetErrorMessage(IntPtr /* (OrtStatus*) */ status); public static DOrtGetErrorMessage OrtGetErrorMessage; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtReleaseStatus(IntPtr /*(OrtStatus*)*/ statusPtr); public static DOrtReleaseStatus OrtReleaseStatus; - #endregion Status API +#endregion Status API - #region InferenceSession API +#region InferenceSession API [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateSession( - IntPtr /* (OrtEnv*) */ environment, - //[MarshalAs(UnmanagedType.LPStr)]string modelPath - byte[] modelPath, - IntPtr /* (OrtSessionOptions*) */sessopnOptions, - out IntPtr /**/ session); + public delegate IntPtr /* OrtStatus* */ DOrtCreateSession( + IntPtr /* (OrtEnv*) */ environment, + //[MarshalAs(UnmanagedType.LPStr)]string modelPath + byte[] modelPath, + IntPtr /* (OrtSessionOptions*) */ sessopnOptions, + out IntPtr /**/ session); public static DOrtCreateSession OrtCreateSession; @@ -765,22 +764,22 @@ internal class NativeLib /// Native OrtPrepackedWeightsContainer instance /// (Output) Created native OrtSession instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateSessionWithPrepackedWeightsContainer( - IntPtr /* (OrtEnv*) */ environment, - byte[] modelPath, - IntPtr /* (OrtSessionOptions*) */sessionOptions, - IntPtr /* (OrtPrepackedWeightsContainer*) */prepackedWeightsContainer, - out IntPtr /* (OrtSession**) */ session); + public delegate IntPtr /* OrtStatus* */ DOrtCreateSessionWithPrepackedWeightsContainer( + IntPtr /* (OrtEnv*) */ environment, + byte[] modelPath, + IntPtr /* (OrtSessionOptions*) */ sessionOptions, + IntPtr /* (OrtPrepackedWeightsContainer*) */ prepackedWeightsContainer, + out IntPtr /* (OrtSession**) */ session); public static DOrtCreateSessionWithPrepackedWeightsContainer OrtCreateSessionWithPrepackedWeightsContainer; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateSessionFromArray( - IntPtr /* (OrtEnv*) */ environment, - byte[] modelData, - UIntPtr modelSize, - IntPtr /* (OrtSessionOptions*) */ sessionOptions, - out IntPtr /**/ session); + public delegate IntPtr /* OrtStatus* */ DOrtCreateSessionFromArray( + IntPtr /* (OrtEnv*) */ environment, + byte[] modelData, + UIntPtr modelSize, + IntPtr /* (OrtSessionOptions*) */ sessionOptions, + out IntPtr /**/ session); public static DOrtCreateSessionFromArray OrtCreateSessionFromArray; /// @@ -793,169 +792,167 @@ internal class NativeLib /// Native OrtPrepackedWeightsContainer instance /// (Output) Created native OrtSession instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /* OrtStatus* */DOrtCreateSessionFromArrayWithPrepackedWeightsContainer( - IntPtr /* (OrtEnv*) */ environment, - byte[] /* (void*) */ modelData, - UIntPtr /* (size_t) */ modelSize, - IntPtr /* (OrtSessionOptions*) */ sessionOptions, - IntPtr /* (OrtPrepackedWeightsContainer*) */prepackedWeightsContainer, - out IntPtr /* (OrtSession**) */ session); + public delegate IntPtr /* OrtStatus* */ DOrtCreateSessionFromArrayWithPrepackedWeightsContainer( + IntPtr /* (OrtEnv*) */ environment, + byte[] /* (void*) */ modelData, + UIntPtr /* (size_t) */ modelSize, + IntPtr /* (OrtSessionOptions*) */ sessionOptions, + IntPtr /* (OrtPrepackedWeightsContainer*) */ prepackedWeightsContainer, + out IntPtr /* (OrtSession**) */ session); public static DOrtCreateSessionFromArrayWithPrepackedWeightsContainer OrtCreateSessionFromArrayWithPrepackedWeightsContainer; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(ONNStatus*)*/ DOrtRun( - IntPtr /*(OrtSession*)*/ session, - IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options - IntPtr[] inputNames, - IntPtr[] /* (OrtValue*[])*/ inputValues, - UIntPtr inputCount, - IntPtr[] outputNames, - UIntPtr outputCount, - IntPtr[] outputValues /* An array of output value pointers. Array must be allocated by the caller */ - ); + IntPtr /*(OrtSession*)*/ session, + IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options + IntPtr[] inputNames, + IntPtr[] /* (OrtValue*[])*/ inputValues, + UIntPtr inputCount, + IntPtr[] outputNames, + UIntPtr outputCount, + IntPtr[] outputValues /* An array of output value pointers. Array must be allocated by the caller */ + ); public static DOrtRun OrtRun; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(ONNStatus*)*/ DOrtRunWithBinding( - IntPtr /*(OrtSession*)*/ session, - IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can not be null - IntPtr /*(const OrtIoBinding*)*/ io_binding - ); + IntPtr /*(OrtSession*)*/ session, + IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can not be null + IntPtr /*(const OrtIoBinding*)*/ io_binding); public static DOrtRunWithBinding OrtRunWithBinding; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetInputCount( - IntPtr /*(OrtSession*)*/ session, - out UIntPtr count); + IntPtr /*(OrtSession*)*/ session, + out UIntPtr count); public static DOrtSessionGetInputCount OrtSessionGetInputCount; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOutputCount( - IntPtr /*(OrtSession*)*/ session, - out UIntPtr count); + IntPtr /*(OrtSession*)*/ session, + out UIntPtr count); public static DOrtSessionGetOutputCount OrtSessionGetOutputCount; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOverridableInitializerCount( - IntPtr /*(OrtSession*)*/ session, - out UIntPtr count); + IntPtr /*(OrtSession*)*/ session, + out UIntPtr count); public static DOrtSessionGetOverridableInitializerCount OrtSessionGetOverridableInitializerCount; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetInputName( - IntPtr /*(OrtSession*)*/ session, - UIntPtr index, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/name); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetInputName( + IntPtr /*(OrtSession*)*/ session, + UIntPtr index, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ name); public static DOrtSessionGetInputName OrtSessionGetInputName; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOutputName( - IntPtr /*(OrtSession*)*/ session, - UIntPtr index, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/name); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOutputName( + IntPtr /*(OrtSession*)*/ session, + UIntPtr index, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ name); public static DOrtSessionGetOutputName OrtSessionGetOutputName; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionEndProfiling( - IntPtr /*(const OrtSession*)*/ session, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/profile_file); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionEndProfiling( + IntPtr /*(const OrtSession*)*/ session, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ profile_file); public static DOrtSessionEndProfiling OrtSessionEndProfiling; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerName( - IntPtr /*(OrtSession*)*/ session, - UIntPtr index, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(char**)*/name); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOverridableInitializerName( + IntPtr /*(OrtSession*)*/ session, + UIntPtr index, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(char**)*/ name); public static DOrtSessionGetOverridableInitializerName OrtSessionGetOverridableInitializerName; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetInputTypeInfo( - IntPtr /*(const OrtSession*)*/ session, - UIntPtr index, - out IntPtr /*(struct OrtTypeInfo**)*/ typeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetInputTypeInfo( + IntPtr /*(const OrtSession*)*/ session, + UIntPtr index, + out IntPtr /*(struct OrtTypeInfo**)*/ typeInfo); public static DOrtSessionGetInputTypeInfo OrtSessionGetInputTypeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOutputTypeInfo( - IntPtr /*(const OrtSession*)*/ session, - UIntPtr index, - out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOutputTypeInfo( + IntPtr /*(const OrtSession*)*/ session, + UIntPtr index, + out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); public static DOrtSessionGetOutputTypeInfo OrtSessionGetOutputTypeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerTypeInfo( - IntPtr /*(const OrtSession*)*/ session, - UIntPtr index, - out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetOverridableInitializerTypeInfo( + IntPtr /*(const OrtSession*)*/ session, + UIntPtr index, + out IntPtr /* (struct OrtTypeInfo**)*/ typeInfo); public static DOrtSessionGetOverridableInitializerTypeInfo OrtSessionGetOverridableInitializerTypeInfo; // release the typeinfo using OrtReleaseTypeInfo [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/session); + public delegate void DOrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/ session); public static DOrtReleaseTypeInfo OrtReleaseTypeInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseSession(IntPtr /*(OrtSession*)*/session); + public delegate void DOrtReleaseSession(IntPtr /*(OrtSession*)*/ session); public static DOrtReleaseSession OrtReleaseSession; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetProfilingStartTimeNs( - IntPtr /*(const OrtSession*)*/ session, - out UIntPtr /*(ulong* out)*/ startTime); + IntPtr /*(const OrtSession*)*/ session, + out UIntPtr /*(ulong* out)*/ startTime); public static DOrtSessionGetProfilingStartTimeNs OrtSessionGetProfilingStartTimeNs; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(ONNStatus*)*/ DCreateAndRegisterAllocatorV2( - IntPtr /* (OrtEnv*) */ environment, - IntPtr /*(char*)*/ provider_type, - IntPtr /*(OrtMemoryInfo*)*/ mem_info, - IntPtr /*(OrtArenaCfg*)*/ arena_cfg, - IntPtr /*(char**)*/ provider_options_keys, - IntPtr /*(char**)*/ provider_options_values, - UIntPtr /*(size_t)*/num_keys); + IntPtr /* (OrtEnv*) */ environment, + IntPtr /*(char*)*/ provider_type, + IntPtr /*(OrtMemoryInfo*)*/ mem_info, + IntPtr /*(OrtArenaCfg*)*/ arena_cfg, + IntPtr /*(char**)*/ provider_options_keys, + IntPtr /*(char**)*/ provider_options_values, + UIntPtr /*(size_t)*/ num_keys); public static DCreateAndRegisterAllocatorV2 OrtCreateAndRegisterAllocatorV2; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(ONNStatus*)*/ DOrtRunAsync( - IntPtr /*(OrtSession*)*/ session, - IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options - IntPtr[] /*(char**)*/ inputNames, - IntPtr[] /*(OrtValue*[])*/ inputValues, - UIntPtr /*(size_t)*/ inputCount, - IntPtr[] /*(char**)*/ outputNames, - UIntPtr /*(size_t)*/ outputCount, - IntPtr[] /*(OrtValue*[])*/ outputValues, - IntPtr /*(void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status))*/ callback, // callback function - IntPtr /*(void*)*/ user_data - ); + IntPtr /*(OrtSession*)*/ session, + IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options + IntPtr[] /*(char**)*/ inputNames, + IntPtr[] /*(OrtValue*[])*/ inputValues, + UIntPtr /*(size_t)*/ inputCount, + IntPtr[] /*(char**)*/ outputNames, + UIntPtr /*(size_t)*/ outputCount, + IntPtr[] /*(OrtValue*[])*/ outputValues, + IntPtr /*(void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status))*/ callback, // callback function + IntPtr /*(void*)*/ user_data); public static DOrtRunAsync OrtRunAsync; - #endregion InferenceSession API +#endregion InferenceSession API - #region SessionOptions API +#region SessionOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateSessionOptions(out IntPtr /*(OrtSessionOptions**)*/ sessionOptions); public static DOrtCreateSessionOptions OrtCreateSessionOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseSessionOptions(IntPtr /*(OrtSessionOptions*)*/session); + public delegate void DOrtReleaseSessionOptions(IntPtr /*(OrtSessionOptions*)*/ session); public static DOrtReleaseSessionOptions OrtReleaseSessionOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -964,7 +961,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionExecutionMode(IntPtr /*(OrtSessionOptions*)*/ options, - ExecutionMode execution_mode); + ExecutionMode execution_mode); public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -996,7 +993,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtDisableCpuMemArena OrtDisableCpuMemArena; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */logId); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */ logId); public static DOrtSetSessionLogId OrtSetSessionLogId; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -1027,7 +1024,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Config value [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtAddSessionConfigEntry(IntPtr /* OrtSessionOptions* */ options, - byte[] /* const char* */configKey, + byte[] /* const char* */ configKey, byte[] /* const char* */ configValue); public static DOrtAddSessionConfigEntry OrtAddSessionConfigEntry; @@ -1090,9 +1087,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native OrtSessionOptions instance /// Native OrtTensorRTProviderOptions instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_TensorRT( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtTensorRTProviderOptions*)*/ trtProviderOptions); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_TensorRT( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtTensorRTProviderOptions*)*/ trtProviderOptions); public static DSessionOptionsAppendExecutionProvider_TensorRT SessionOptionsAppendExecutionProvider_TensorRT; @@ -1102,9 +1099,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native OrtSessionOptions instance /// Native OrtTensorRTProviderOptionsV2 instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_TensorRT_V2( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtTensorRTProviderOptionsV2*)*/ trtProviderOptions); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_TensorRT_V2( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtTensorRTProviderOptionsV2*)*/ trtProviderOptions); public static DSessionOptionsAppendExecutionProvider_TensorRT_V2 SessionOptionsAppendExecutionProvider_TensorRT_V2; @@ -1114,9 +1111,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native OrtSessionOptions instance /// Native OrtCUDAProviderOptions instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_CUDA( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtCUDAProviderOptions*)*/ cudaProviderOptions); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_CUDA( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtCUDAProviderOptions*)*/ cudaProviderOptions); public static DSessionOptionsAppendExecutionProvider_CUDA SessionOptionsAppendExecutionProvider_CUDA; @@ -1126,9 +1123,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native OrtSessionOptions instance /// Native OrtCUDAProviderOptionsV2 instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_CUDA_V2( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtCUDAProviderOptionsV2*)*/ cudaProviderOptions); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_CUDA_V2( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtCUDAProviderOptionsV2*)*/ cudaProviderOptions); public static DSessionOptionsAppendExecutionProvider_CUDA_V2 SessionOptionsAppendExecutionProvider_CUDA_V2; @@ -1138,9 +1135,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native OrtSessionOptions instance /// Native OrtROCMProviderOptions instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_ROCM( - IntPtr /*(OrtSessionOptions*)*/ options, - IntPtr /*(const OrtROCMProviderOptions*)*/ rocmProviderOptions); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider_ROCM( + IntPtr /*(OrtSessionOptions*)*/ options, + IntPtr /*(const OrtROCMProviderOptions*)*/ rocmProviderOptions); public static DSessionOptionsAppendExecutionProvider_ROCM SessionOptionsAppendExecutionProvider_ROCM; @@ -1151,9 +1148,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Dimension denotation /// Dimension value [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*)*/ options, - byte[] /*(const char*)*/ dimDenotation, - long dimValue); + public delegate IntPtr /*(OrtStatus*)*/ DOrtAddFreeDimensionOverride(IntPtr /*(OrtSessionOptions*)*/ options, + byte[] /*(const char*)*/ dimDenotation, + long dimValue); public static DOrtAddFreeDimensionOverride OrtAddFreeDimensionOverride; @@ -1164,9 +1161,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Dimension name /// Dimension value [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtAddFreeDimensionOverrideByName(IntPtr /*(OrtSessionOptions*)*/ options, - byte[] /*(const char*)*/ dimName, - long dimValue); + public delegate IntPtr /*(OrtStatus*)*/ DOrtAddFreeDimensionOverrideByName(IntPtr /*(OrtSessionOptions*)*/ options, + byte[] /*(const char*)*/ dimName, + long dimValue); public static DOrtAddFreeDimensionOverrideByName OrtAddFreeDimensionOverrideByName; @@ -1177,9 +1174,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Library path /// (out) Native library handle [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options, - byte[] /*(const char*)*/ libraryPath, - out IntPtr /*(void**)*/ libraryHandle); + public delegate IntPtr /*(OrtStatus*)*/ DOrtRegisterCustomOpsLibrary(IntPtr /*(OrtSessionOptions*) */ options, + byte[] /*(const char*)*/ libraryPath, + out IntPtr /*(void**)*/ libraryHandle); public static DOrtRegisterCustomOpsLibrary OrtRegisterCustomOpsLibrary; @@ -1189,8 +1186,8 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Native SessionOptions instance /// Library path [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtRegisterCustomOpsLibrary_V2(IntPtr /*(OrtSessionOptions*) */ options, - byte[] /*(const ORTCHAR_T*)*/ libraryPath); + public delegate IntPtr /*(OrtStatus*)*/ DOrtRegisterCustomOpsLibrary_V2(IntPtr /*(OrtSessionOptions*) */ options, + byte[] /*(const ORTCHAR_T*)*/ libraryPath); public static DOrtRegisterCustomOpsLibrary_V2 OrtRegisterCustomOpsLibrary_V2; @@ -1201,9 +1198,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Name of the initializer /// Native OrtValue instnce [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options, - byte[] /*(const char*)*/ name, - IntPtr /*(OrtValue*)*/ ortValue); + public delegate IntPtr /*(OrtStatus*)*/ DOrtAddInitializer(IntPtr /*(OrtSessionOptions*)*/ options, + byte[] /*(const char*)*/ name, + IntPtr /*(OrtValue*)*/ ortValue); public static DOrtAddInitializer OrtAddInitializer; @@ -1220,25 +1217,25 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Configuration values to add /// Number of configuration keys [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider( - IntPtr /*(OrtSessionOptions*)*/ options, - byte[] /*(const char*)*/ providerName, - IntPtr[] /*(const char* const *)*/ providerOptionsKeys, - IntPtr[] /*(const char* const *)*/ providerOptionsValues, - UIntPtr /*(size_t)*/ numKeys); + public delegate IntPtr /*(OrtStatus*)*/ DSessionOptionsAppendExecutionProvider( + IntPtr /*(OrtSessionOptions*)*/ options, + byte[] /*(const char*)*/ providerName, + IntPtr[] /*(const char* const *)*/ providerOptionsKeys, + IntPtr[] /*(const char* const *)*/ providerOptionsValues, + UIntPtr /*(size_t)*/ numKeys); public static DSessionOptionsAppendExecutionProvider SessionOptionsAppendExecutionProvider; - #endregion +#endregion - #region RunOptions API +#region RunOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateRunOptions(out IntPtr /* OrtRunOptions** */ runOptions); public static DOrtCreateRunOptions OrtCreateRunOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DOrtReleaseRunOptions(IntPtr /*(OrtRunOptions*)*/options); + public delegate void DOrtReleaseRunOptions(IntPtr /*(OrtRunOptions*)*/ options); public static DOrtReleaseRunOptions OrtReleaseRunOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -1259,11 +1256,11 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsGetRunLogSeverityLevel(IntPtr /* OrtRunOptions* */ options, - out OrtLoggingLevel severityLevel); + out OrtLoggingLevel severityLevel); public static DOrtRunOptionsGetRunLogSeverityLevel OrtRunOptionsGetRunLogSeverityLevel; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsGetRunTag(IntPtr /* const OrtRunOptions* */options, out IntPtr /* const char** */ runtag); + public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsGetRunTag(IntPtr /* const OrtRunOptions* */ options, out IntPtr /* const char** */ runtag); public static DOrtRunOptionsGetRunTag OrtRunOptionsGetRunTag; // Set a flag so that any running OrtRun* calls that are using this instance of OrtRunOptions @@ -1276,7 +1273,6 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsUnsetTerminate(IntPtr /* OrtRunOptions* */ options); public static DOrtRunOptionsUnsetTerminate OrtRunOptionsUnsetTerminate; - /// /// Add run config entry /// @@ -1285,13 +1281,13 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Config value [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtAddRunConfigEntry(IntPtr /* OrtRunOptions* */ options, - byte[] /* const char* */configKey, + byte[] /* const char* */ configKey, byte[] /* const char* */ configValue); public static DOrtAddRunConfigEntry OrtAddRunConfigEntry; - #endregion +#endregion - #region ThreadingOptions API +#region ThreadingOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateThreadingOptions(out IntPtr /* OrtCreateThreadingOptions** */ threadingOptions); @@ -1316,27 +1312,26 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtThreadingOptionsSetGlobalSpinControl(IntPtr /* OrtThreadingOptions* */ threadingOptions, int allowSpinning); public static DOrtThreadingOptionsSetGlobalSpinControl OrtThreadingOptionsSetGlobalSpinControl; - #endregion +#endregion - #region Allocator/MemoryInfo API +#region Allocator / MemoryInfo API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateMemoryInfo( - byte[] /*(const char*) */name, - OrtAllocatorType allocatorType, - int identifier, - OrtMemType memType, - out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo // memory ownership transfered to caller - ); + byte[] /*(const char*) */ name, + OrtAllocatorType allocatorType, + int identifier, + OrtMemType memType, + out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo // memory ownership transfered to caller + ); public static DOrtCreateMemoryInfo OrtCreateMemoryInfo; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateCpuMemoryInfo( - OrtAllocatorType allocatorType, - OrtMemType memoryType, - out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo - ); + OrtAllocatorType allocatorType, + OrtMemType memoryType, + out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo); public static DOrtCreateCpuMemoryInfo OrtCreateCpuMemoryInfo; @@ -1347,15 +1342,15 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCompareMemoryInfo( - IntPtr /*(const OrtMemoryInfo*)*/ info1, - IntPtr /*(const OrtMemoryInfo*)*/ info2, - out int /*(int* out)*/ result); + IntPtr /*(const OrtMemoryInfo*)*/ info1, + IntPtr /*(const OrtMemoryInfo*)*/ info2, + out int /*(int* out)*/ result); public static DOrtCompareMemoryInfo OrtCompareMemoryInfo; /** - * Do not free the returned value - */ + * Do not free the returned value + */ [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtMemoryInfoGetName(IntPtr /*(const OrtMemoryInfo* ptr)*/ mem_info, out IntPtr /*(const char**)*/ name); @@ -1368,26 +1363,25 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtMemoryInfoGetMemType( - IntPtr /*(const OrtMemoryInfo* ptr)*/ mem_info, - out OrtMemType /*(OrtMemType*)*/ mem_type); + IntPtr /*(const OrtMemoryInfo* ptr)*/ mem_info, + out OrtMemType /*(OrtMemType*)*/ mem_type); public static DOrtMemoryInfoGetMemType OrtMemoryInfoGetMemType; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtMemoryInfoGetType( - IntPtr /*(const OrtMemoryInfo* ptr)*/ mem_info, - out OrtAllocatorType /*(OrtAllocatorType*)*/ alloc_type - ); + IntPtr /*(const OrtMemoryInfo* ptr)*/ mem_info, + out OrtAllocatorType /*(OrtAllocatorType*)*/ alloc_type); public static DOrtMemoryInfoGetType OrtMemoryInfoGetType; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtGetAllocatorWithDefaultOptions(out IntPtr /*(OrtAllocator**)*/ allocator); + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetAllocatorWithDefaultOptions(out IntPtr /*(OrtAllocator**)*/ allocator); public static DOrtGetAllocatorWithDefaultOptions OrtGetAllocatorWithDefaultOptions; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/DOrtAllocatorGetInfo(IntPtr /*(const OrtAllocator*)*/ ptr, out IntPtr /*(const struct OrtMemoryInfo**)*/info); + public delegate IntPtr /*(OrtStatus*)*/ DOrtAllocatorGetInfo(IntPtr /*(const OrtAllocator*)*/ ptr, out IntPtr /*(const struct OrtMemoryInfo**)*/ info); public static DOrtAllocatorGetInfo OrtAllocatorGetInfo; @@ -1402,8 +1396,8 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// Pointer to a native OrtStatus instance indicating success/failure of config creation [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateArenaCfg(UIntPtr /*(size_t)*/ maxMemory, int /*(int)*/ arenaExtendStrategy, - int /*(int)*/ initialChunkSizeBytes, int /*(int)*/ maxDeadBytesPerChunk, - out IntPtr /*(OrtArenaCfg**)*/ arenaCfg); + int /*(int)*/ initialChunkSizeBytes, int /*(int)*/ maxDeadBytesPerChunk, + out IntPtr /*(OrtArenaCfg**)*/ arenaCfg); public static DOrtCreateArenaCfg OrtCreateArenaCfg; @@ -1457,9 +1451,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtAllocatorFree OrtAllocatorFree; - #endregion Allocator/MemoryInfo API +#endregion Allocator / MemoryInfo API - #region IoBinding API +#region IoBinding API /// /// Create OrtIoBinding instance that is used to bind memory that is allocated @@ -1634,7 +1628,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateAndRegisterAllocator(IntPtr /*(OrtEnv*)*/ env, IntPtr /*(const OrtMemoryInfo*)*/ memInfo, - IntPtr/*(const OrtArenaCfg*)*/ arenaCfg); + IntPtr /*(const OrtArenaCfg*)*/ arenaCfg); public static DOrtCreateAndRegisterAllocator OrtCreateAndRegisterAllocator; @@ -1644,13 +1638,13 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// the source projected language [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetLanguageProjection(IntPtr /* (OrtEnv*) */ environment, - int projection); + int projection); public static DOrtSetLanguageProjection OrtSetLanguageProjection; - #endregion IoBinding API +#endregion IoBinding API - #region ModelMetadata API +#region ModelMetadata API /// /// Gets the ModelMetadata associated with an InferenceSession @@ -1670,7 +1664,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) producer name from the ModelMetadata instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetProducerName OrtModelMetadataGetProducerName; @@ -1682,7 +1676,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) graph name from the ModelMetadata instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetGraphName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetGraphName OrtModelMetadataGetGraphName; @@ -1694,7 +1688,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) domain from the ModelMetadata instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetDomain(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetDomain OrtModelMetadataGetDomain; @@ -1706,7 +1700,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) description from the ModelMetadata instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetDescription(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetDescription OrtModelMetadataGetDescription; @@ -1718,7 +1712,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) graph description from the ModelMetadata instance [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetGraphDescription(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static DOrtModelMetadataGetGraphDescription OrtModelMetadataGetGraphDescription; @@ -1742,7 +1736,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) number of keys in the custom metadata map [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetCustomMetadataMapKeys(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char***) */ keys, out long /* (int64_t*) */ numKeys); + IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char***) */ keys, out long /* (int64_t*) */ numKeys); public static DOrtModelMetadataGetCustomMetadataMapKeys OrtModelMetadataGetCustomMetadataMapKeys; @@ -1755,7 +1749,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// (output) value for the key in the custom metadata map [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataLookupCustomMetadataMap(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, - IntPtr /* (OrtAllocator*) */ allocator, IntPtr /* (const char*) */ key, out IntPtr /* (char**) */ value); + IntPtr /* (OrtAllocator*) */ allocator, IntPtr /* (const char*) */ key, out IntPtr /* (char**) */ value); public static DOrtModelMetadataLookupCustomMetadataMap OrtModelMetadataLookupCustomMetadataMap; @@ -1768,9 +1762,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtReleaseModelMetadata OrtReleaseModelMetadata; - #endregion ModelMetadata API +#endregion ModelMetadata API - #region OrtValue API +#region OrtValue API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtHasValue(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(int*)*/ hasValue); @@ -1779,9 +1773,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetValue(IntPtr /*(OrtValue*)*/ value, - int index, - IntPtr /*(OrtAllocator*)*/ allocator, - out IntPtr /*(OrtValue**)*/ outputValue); + int index, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(OrtValue**)*/ outputValue); public static DOrtGetValue OrtGetValue; @@ -1801,8 +1795,8 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtGetValueCount OrtGetValueCount; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr/*(OrtStatus*)*/ DOrtCreateValue(IntPtr[] /* const OrtValue* const* in */ values, - UIntPtr /* size_t */ num_values, IntPtr /* (OnnxValueType */ onnxValueType, out IntPtr /* OrtValue** */ ortValue); + public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateValue(IntPtr[] /* const OrtValue* const* in */ values, + UIntPtr /* size_t */ num_values, IntPtr /* (OnnxValueType */ onnxValueType, out IntPtr /* OrtValue** */ ortValue); public static DOrtCreateValue OrtCreateValue; @@ -1813,23 +1807,23 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateTensorAsOrtValue( - IntPtr /*_Inout_ OrtAllocator* */ allocator, - long[] /*_In_ const int64_t* */ shape, - UIntPtr /*size_t*/ shape_len, - Tensors.TensorElementType type, - out IntPtr /* OrtValue** */ outputValue); + IntPtr /*_Inout_ OrtAllocator* */ allocator, + long[] /*_In_ const int64_t* */ shape, + UIntPtr /*size_t*/ shape_len, + Tensors.TensorElementType type, + out IntPtr /* OrtValue** */ outputValue); public static DOrtCreateTensorAsOrtValue OrtCreateTensorAsOrtValue; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus */ DOrtCreateTensorWithDataAsOrtValue( - IntPtr /* (const OrtMemoryInfo*) */ allocatorInfo, - IntPtr /* (void*) */dataBufferHandle, - UIntPtr dataLength, - long[] shape, - UIntPtr shapeLength, - Tensors.TensorElementType type, - out IntPtr /* OrtValue** */ outputValue); + IntPtr /* (const OrtMemoryInfo*) */ allocatorInfo, + IntPtr /* (void*) */ dataBufferHandle, + UIntPtr dataLength, + long[] shape, + UIntPtr shapeLength, + Tensors.TensorElementType type, + out IntPtr /* OrtValue** */ outputValue); public static DOrtCreateTensorWithDataAsOrtValue OrtCreateTensorWithDataAsOrtValue; @@ -1854,9 +1848,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca /// \param len total data length, not including the trailing '\0' chars. [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtFillStringTensor( - IntPtr /* OrtValue */ value, - IntPtr[] /* const char* const* */s, - UIntPtr /* size_t */ s_len); + IntPtr /* OrtValue */ value, + IntPtr[] /* const char* const* */ s, + UIntPtr /* size_t */ s_len); public static DOrtFillStringTensor OrtFillStringTensor; @@ -1876,7 +1870,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetStringTensorContent( IntPtr /*(OrtValue*)*/ value, - byte[] /*(void*)*/ dst_buffer, + byte[] /*(void*)*/ dst_buffer, UIntPtr dst_buffer_len, UIntPtr[] offsets, UIntPtr offsets_len); @@ -1913,7 +1907,7 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorTypeAndShape( - IntPtr /*(OrtValue*)*/ value, + IntPtr /*(OrtValue*)*/ value, out IntPtr /*(struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo); public static DOrtGetTensorTypeAndShape OrtGetTensorTypeAndShape; @@ -1932,35 +1926,35 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensionsCount( - IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, out UIntPtr output); public static DOrtGetDimensionsCount OrtGetDimensionsCount; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetDimensions( - IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, - long[] dim_values, - UIntPtr dim_values_length); + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + long[] dim_values, + UIntPtr dim_values_length); public static DOrtGetDimensions OrtGetDimensions; /** - * Get the symbolic dimension names for dimensions with a value of -1. - * Order and number of entries is the same as values returned by GetDimensions. - * The name may be empty for an unnamed symbolic dimension. - * e.g. - * If OrtGetDimensions returns [-1, -1, 2], OrtGetSymbolicDimensions would return an array with 3 entries. - * If the values returned were ['batch', '', ''] it would indicate that - * - the first dimension was a named symbolic dimension (-1 dim value and name in symbolic dimensions), - * - the second dimension was an unnamed symbolic dimension (-1 dim value and empty string), - * - the entry for the third dimension should be ignored as it is not a symbolic dimension (dim value >= 0). - */ + * Get the symbolic dimension names for dimensions with a value of -1. + * Order and number of entries is the same as values returned by GetDimensions. + * The name may be empty for an unnamed symbolic dimension. + * e.g. + * If OrtGetDimensions returns [-1, -1, 2], OrtGetSymbolicDimensions would return an array with 3 entries. + * If the values returned were ['batch', '', ''] it would indicate that + * - the first dimension was a named symbolic dimension (-1 dim value and name in symbolic dimensions), + * - the second dimension was an unnamed symbolic dimension (-1 dim value and empty string), + * - the entry for the third dimension should be ignored as it is not a symbolic dimension (dim value >= 0). + */ [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetSymbolicDimensions( - IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, - IntPtr[] dim_params, /* const char* values, converted to string by caller */ - UIntPtr dim_params_length); + IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, + IntPtr[] dim_params, /* const char* values, converted to string by caller */ + UIntPtr dim_params_length); public static DOrtGetSymbolicDimensions OrtGetSymbolicDimensions; @@ -1975,15 +1969,15 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca */ [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorShapeElementCount(IntPtr /*(const struct OrtTensorTypeAndShapeInfo*)*/ typeAndShapeInfo, - out UIntPtr /* size_t */ output); + out UIntPtr /* size_t */ output); public static DOrtGetTensorShapeElementCount OrtGetTensorShapeElementCount; - [UnmanagedFunctionPointer(CallingConvention.Winapi)] // The out ortMemoryInfo must not be destroyed/deallocated. The pointer points to an object owned by // the contained Tensor/SparseTensor. + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorMemoryInfo(IntPtr /* const OrtValue* */ ortValue, - out IntPtr /* const OrtMemoryInfo** */ ortMemoryInfo); + out IntPtr /* const OrtMemoryInfo** */ ortMemoryInfo); public static DOrtGetTensorMemoryInfo OrtGetTensorMemoryInfo; @@ -1993,10 +1987,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DCastTypeInfoToMapTypeInfo OrtCastTypeInfoToMapTypeInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DGetMapKeyType(IntPtr /*const OrtMapTypeInfo* */ mapTypeInfo, out IntPtr /*(TensorElementType*)*/ tensorElementType); public static DGetMapKeyType OrtGetMapKeyType; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DGetMapValueType(IntPtr /* const OrtMapTypeInfo* */ map_type_info, out IntPtr /* OrtTypeInfo** */ type_info); public static DGetMapValueType OrtGetMapValueType; @@ -2007,13 +2003,14 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DCastTypeInfoToSequenceTypeInfo OrtCastTypeInfoToSequenceTypeInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DGetSequenceElementType(IntPtr /* const OrtSequenceTypeInfo* */ sequenceTypeInfo, out IntPtr /* OrtTypeInfo** */ elementTypeInfo); public static DGetSequenceElementType OrtGetSequenceElementType; // OptionalTypeInfo [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToOptionalTypeInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /* const struct OrtOptionalTypeInfo** */ optionalTypeInfo); + public delegate IntPtr /*(OrtStatus*)*/ DOrtCastTypeInfoToOptionalTypeInfo(IntPtr /*(struct OrtTypeInfo*)*/ typeInfo, out IntPtr /* const struct OrtOptionalTypeInfo** */ optionalTypeInfo); public static DOrtCastTypeInfoToOptionalTypeInfo OrtCastTypeInfoToOptionalTypeInfo; @@ -2027,10 +2024,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtReleaseValue OrtReleaseValue; - #endregion +#endregion - - #region Misc API +#region Misc API /// /// Queries all the execution providers supported in the native onnxruntime shared library @@ -2070,8 +2066,8 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DOrtReleasePrepackedWeightsContainer OrtReleasePrepackedWeightsContainer; - #endregion - } //class NativeMethods +#endregion + } // class NativeMethods // onnxruntime-extensions helpers to make usage simpler. // The onnxruntime-extensions nuget package containing the native library can be optionally added to the app. @@ -2092,7 +2088,5 @@ internal static class OrtExtensionsNativeMethods CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OrtStatus* */ RegisterCustomOps(IntPtr /* OrtSessionOptions* */ sessionOptions, ref OrtApiBase /* OrtApiBase* */ ortApiBase); - - } -} //namespace +} // namespace diff --git a/csharp/tools/ValidateNativeDelegateAttributes.py b/csharp/tools/ValidateNativeDelegateAttributes.py new file mode 100644 index 0000000000000..acd6c173bfeb0 --- /dev/null +++ b/csharp/tools/ValidateNativeDelegateAttributes.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import pathlib + + +def check_all_delegates_have_unmanaged_function_pointer_attribute(file: pathlib.Path): + """ + Check that all 'public delegate' declarations have a matching UnmanagedFunctionPointer attribute. + :param file: C# source file to check. + :return: Number of errors + """ + + print(f"Checking {file!s}") + + errors = 0 + line_num = 0 + with open(str(file.resolve(strict=True))) as f: + prev_line = "" + for line in f.readlines(): + line_num += 1 + + # strip so it's easier to deal with commented out lines. + line = line.strip() # noqa + if line.startswith("public delegate ") and not prev_line.startswith("[UnmanagedFunctionPointer"): + errors += 1 + print(f"Line {line_num} is missing UnmanagedFunctionPointer attribute:\n\t{prev_line}\n\t{line}") + + prev_line = line + + return errors + + +def main(): + arg_parser = argparse.ArgumentParser( + "Script to validate that the native delegates for the ONNX Runtime C# managed projects have the required " + "attributes for iOS AOT. Paths are inferred from the script location." + "Errors of this nature can only be detected at runtime, in a release build, of a Xamarin/MAUI app, " + "on an actual iOS device. Due to that we take extra steps to identify problems early." + ) + + # no real args. just using this to provide description as help message + _ = arg_parser.parse_args() + + # CI needs resolve() as __file__ is a relative path when the script is run there + script_dir = pathlib.Path(__file__).resolve().parent + csharp_root = script_dir.parent + + managed_dir = csharp_root / "src" / "Microsoft.ML.OnnxRuntime" + native_methods = managed_dir / "NativeMethods.shared.cs" + training_native_methods = managed_dir / "Training" / "NativeTrainingMethods.shared.cs" + errors = check_all_delegates_have_unmanaged_function_pointer_attribute(native_methods) + errors += check_all_delegates_have_unmanaged_function_pointer_attribute(training_native_methods) + + if errors: + raise ValueError(f"{errors} errors were found. Please check output for specifics.") + + +if __name__ == "__main__": + main() diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 169737c9138c3..c73f978bdf404 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -54,6 +54,7 @@ Do not modify directly.* * com.microsoft.MatMulIntegerToFloat * com.microsoft.MatMulNBits * com.microsoft.MaxpoolWithMask + * com.microsoft.MoE * com.microsoft.MulInteger * com.microsoft.MultiHeadAttention * com.microsoft.MurmurHash3 @@ -2384,7 +2385,7 @@ This version of the operator has been available since version 1 of the 'com.micr Group Query Self/Cross Attention. - Supports different number of heads for q and kv. + Supports different number of heads for q and kv. Only supports causal or local attention. #### Version @@ -2395,6 +2396,8 @@ This version of the operator has been available since version 1 of the 'com.micr
kv_num_heads : int (required)
Number of attention heads for k and v
+
local_window_size : int
+
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
scale : float
@@ -2646,8 +2649,8 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T1 : tensor(float), tensor(float16)
-
Constrain input and output types to float/half_float tensors.
+
T1 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float/half_float/brain_float tensors.
T2 : tensor(uint8)
Constrain quantized weight types to uint8.
@@ -2904,6 +2907,58 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.MoE** + + Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, + GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) + usually uses top 32 experts. + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation_type : string
+
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
k : int
+
Number of top experts to select from expert pool
+
+ +#### Inputs (4 - 6) + +
+
input : T
+
2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+
router_probs : T
+
2D input tensor with shape (num_rows, num_experts)
+
fc1_experts_weights : T
+
3D input tensor with shape (num_experts, hidden_size, inter_size)
+
fc2_experts_weights : T
+
3D input tensor with shape (num_experts, inter_size, hidden_size)
+
fc1_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
+
fc2_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, hidden_size)
+
+ +#### Outputs + +
+
output : T
+
2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float or float16 tensors.
+
+ + ### **com.microsoft.MulInteger** Performs element-wise binary quantized multiplication (with Numpy-style broadcasting support). @@ -4968,7 +5023,7 @@ This version of the operator has been available since version 1 of the 'com.micr
input : T
-
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
@@ -4981,7 +5036,7 @@ This version of the operator has been available since version 1 of the 'com.micr
output : T
-
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
tensor with same shape as input.
#### Type Constraints diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 12733c3551704..7fa89cca381d9 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -269,6 +269,15 @@ data sparsity based performance optimizations. unset ORTMODULE_CACHE_DIR # Disable ``` +#### ORTMODULE_USE_EFFICIENT_ATTENTION + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and falling back to PyTorch's efficient_attention ATen kernel for execution. NOTE that it requires torch's version is 2.1.1 or above. There are some build-in patterns for attention fusion, if none of the patterns works for your model, you can add a custom one in your user script manually. + + ```bash + export ORTMODULE_USE_EFFICIENT_ATTENTION=1 + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* @@ -397,6 +406,15 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training export ORTMODULE_TUNING_RESULTS_PATH=/tmp/tuning_results ``` +#### ORTMODULE_USE_FLASH_ATTENTION + +- **Feature Area**: *ORTMODULE/TritonOp* +- **Description**: By default, this is disabled. This env var can be used for enabling attention fusion and using Flash Attention's Triton version as the kernel. NOTE that it requires ORTMODULE_USE_TRITON to be enabled, and CUDA device capability is 8.0 or above. There are some build-in patterns for attention fusion, if none of the patterns works for your model, you can add a custom one in your user script manually. + + ```bash + export ORTMODULE_USE_FLASH_ATTENTION=1 + ``` + #### ORTMODULE_TRITON_DEBUG - **Feature Area**: *ORTMODULE/TritonOp* diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8e546b30aa4cb..16df788c284ee 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -840,8 +840,9 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/framework/op_node_proto_helper.h b/include/onnxruntime/core/framework/op_node_proto_helper.h index 700e1edc0cb7d..e7ac01947af41 100644 --- a/include/onnxruntime/core/framework/op_node_proto_helper.h +++ b/include/onnxruntime/core/framework/op_node_proto_helper.h @@ -10,20 +10,6 @@ #include "core/common/gsl.h" #endif -#ifdef __has_attribute -#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x) -#else -#define ORT_HAVE_ATTRIBUTE(x) 0 -#endif - -#if ORT_HAVE_ATTRIBUTE(nodiscard) -#define MUST_USE_RESULT [[nodiscard]] -#elif defined(__clang__) && ORT_HAVE_ATTRIBUTE(warn_unused_result) -#define MUST_USE_RESULT __attribute__((warn_unused_result)) -#else -#define MUST_USE_RESULT -#endif - class IMLOpKernel; namespace onnxruntime { @@ -43,14 +29,26 @@ class OpNodeProtoHelper { Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema */ template - MUST_USE_RESULT Status GetAttr(const std::string& name, T* value) const; + Status GetAttr(const std::string& name, T* value) const; + + /** + Get a single attribute + Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema + Throws if an attribute with the specified type doesn't exist + */ + template + [[nodiscard]] T GetAttr(const std::string& name) const { + T value; + ORT_THROW_IF_ERROR(GetAttr(name, &value)); + return value; + } /** Get a single attribute Call this function only when a default value for an optional attribute isn't specified in the op schema */ template - T GetAttrOrDefault(const std::string& name, const T& default_value) const { + [[nodiscard]] T GetAttrOrDefault(const std::string& name, const T& default_value) const { T tmp; return GetAttr(name, &tmp).IsOK() ? tmp : default_value; } @@ -70,7 +68,8 @@ class OpNodeProtoHelper { Call this function only when a default value for an optional attribute isn't specified in the op schema */ template - MUST_USE_RESULT std::vector GetAttrsOrDefault(const std::string& name, const std::vector& default_value = std::vector{}) const { + [[nodiscard]] std::vector GetAttrsOrDefault(const std::string& name, + const std::vector& default_value = {}) const { std::vector tmp; return GetAttrs(name, tmp).IsOK() ? tmp : default_value; } @@ -87,11 +86,12 @@ class OpNodeProtoHelper { /// Attribute data in a span, out parameter /// Status template - MUST_USE_RESULT Status GetAttrsAsSpan(const std::string& name, gsl::span& values) const; + Status GetAttrsAsSpan(const std::string& name, gsl::span& values) const; - MUST_USE_RESULT Status GetAttrs(const std::string& name, TensorShapeVector& out) const; + Status GetAttrs(const std::string& name, TensorShapeVector& out) const; - MUST_USE_RESULT TensorShapeVector GetAttrsOrDefault(const std::string& name, const TensorShapeVector& default_value = TensorShapeVector{}) const { + [[nodiscard]] TensorShapeVector GetAttrsOrDefault(const std::string& name, + const TensorShapeVector& default_value = {}) const { TensorShapeVector tmp; return GetAttrs(name, tmp).IsOK() ? tmp : default_value; } @@ -100,43 +100,43 @@ class OpNodeProtoHelper { Get repeated attributes */ template - MUST_USE_RESULT Status GetAttrs(const std::string& name, std::vector& values) const; + Status GetAttrs(const std::string& name, std::vector& values) const; template - MUST_USE_RESULT Status GetAttrs(const std::string& name, gsl::span values) const; + Status GetAttrs(const std::string& name, gsl::span values) const; - MUST_USE_RESULT Status GetAttrsStringRefs(const std::string& name, - std::vector>& refs) const; + Status GetAttrsStringRefs(const std::string& name, + std::vector>& refs) const; - uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type, - const std::string& name) const noexcept; + [[nodiscard]] uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type, + const std::string& name) const noexcept; - bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type, - const std::string& name) const noexcept; + [[nodiscard]] bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type, + const std::string& name) const noexcept; - uint32_t GetInputCount() const { + [[nodiscard]] uint32_t GetInputCount() const { return gsl::narrow_cast(impl_->getNumInputs()); } - uint32_t GetOutputCount() const { + [[nodiscard]] uint32_t GetOutputCount() const { return gsl::narrow_cast(impl_->getNumOutputs()); } - const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const { + [[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const { return impl_->getInputType(index); } - const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const { + [[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const { // Work around lack of a const method from the onnx InferenceContext interface return const_cast(impl_)->getOutputType(index); } // Try to query an attribute, returning nullptr if it doesn't exist - const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const { + [[nodiscard]] const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const { return impl_->getAttribute(name); } - const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const { + [[nodiscard]] const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const { const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name); ORT_ENFORCE(attr != nullptr); return attr; diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h index b3783696b8d78..82a1c1de83523 100644 --- a/include/onnxruntime/core/framework/tensor_shape.h +++ b/include/onnxruntime/core/framework/tensor_shape.h @@ -2,34 +2,17 @@ // Licensed under the MIT License. #pragma once -#include -#include + #include -#include #include -#include "core/common/gsl.h" -#include "onnxruntime_config.h" - -#ifndef DISABLE_ABSEIL -// Need to include abseil inlined_vector.h header directly here -// as hash tables cause CUDA 10.2 compilers to fail. inlined_vector.h is fine. -#ifdef _MSC_VER -#pragma warning(push) -// C4127: conditional expression is constant -#pragma warning(disable : 4127) -// C4324: structure was padded due to alignment specifier -// Usage of alignas causes some internal padding in places. -#pragma warning(disable : 4324) -#endif - -#include - -#ifdef _MSC_VER -#pragma warning(pop) -#endif -#endif // DISABLE_ABSEIL +#include +#include +#include +#include "core/common/gsl.h" +#include "core/common/inlined_containers_fwd.h" #include "core/common/span_utils.h" +#include "onnxruntime_config.h" namespace onnxruntime { #ifdef __GNUC__ @@ -41,18 +24,10 @@ namespace onnxruntime { constexpr size_t kTensorShapeSmallBufferElementsSize = 5; -#ifndef DISABLE_ABSEIL // Use this type to build a shape and then create TensorShape. -using TensorShapeVector = absl::InlinedVector; -#else -class TensorShapeVector : public std::vector { - using Base = std::vector; - - public: - using Base::Base; -}; - -#endif // DISABLE_ABSEIL +// We opt to re-use a common instantiation instead of a typedef with kTensorShapeSmallBufferElementsSize +// To reduce on binary size. +using TensorShapeVector = InlinedVector; inline TensorShapeVector ToShapeVector(const gsl::span& span) { TensorShapeVector out; @@ -194,9 +169,7 @@ class TensorShape { friend struct ProviderHostImpl; // So that the shared provider interface can access Allocate }; -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif + // operator<< to nicely output to a stream std::ostream& operator<<(std::ostream& out, const TensorShape& shape); diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 462d410e13769..fe0734c51f807 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -397,6 +397,10 @@ class Node { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Remove the specified attribute from this Node */ bool ClearAttribute(const std::string& attr_name); + + /** Gets the Node's mutable attributes. */ + NodeAttributes& GetMutableAttributes() noexcept { return attributes_; } + #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** @@ -406,8 +410,6 @@ class Node { int PruneRemovableAttributes(gsl::span removable_attributes); #if !defined(ORT_MINIMAL_BUILD) - /** Gets the Node's mutable attributes. */ - NodeAttributes& GetMutableAttributes() noexcept { return attributes_; } /** Gets the Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve. @param attr_name Attribute name for the GraphProto attribute. @@ -441,6 +443,13 @@ class Node { return attr_to_subgraph_map_; } + /** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node. + * @returns a mutable map of mutable subgraphs. + */ + std::unordered_map>& GetMutableMapOfAttributeNameToSubgraph() { + return attr_to_subgraph_map_; + } + /** Gets a map of attribute name to the const Graph instances for all subgraphs of the Node. @returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance. nullptr if the Node has no subgraphs. @@ -586,7 +595,7 @@ class Node { // create a Graph instance for an attribute that contains a GraphProto void CreateSubgraph(const std::string& attr_name); - const std::vector>& MutableSubgraphs() noexcept { return subgraphs_; } + std::vector>& MutableSubgraphs() noexcept { return subgraphs_; } // validate and update the input arg count common::Status UpdateInputArgCount(); @@ -1134,6 +1143,26 @@ class Graph { */ Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name); + /** + Directly insert one of the If node branches into this Graph. + `If` node condition must be a constant. The function would + rename the nodes of the corresponding subgraph to make sure there is no conflict. + + Explicit and implicit inputs references stay the same. + + All of the outputs of the subgraph being inlined should be renamed + to the outputs of the If node. + + The function will process any subgraphs in each of the nodes being inlined, + and will rename any references to the new names introduced. + + @param condition_value If condition value + @param if_node - the node that contains the graph_to_inline. This node is going + to be deleted and replaced by the corresponding graph (either then or else) + @param logger + */ + Status InlineIfSubgraph(bool condition_value, Node& if_node, const logging::Logger& logger); + /** Directly insert the nodes in the function Node provided into this Graph. The Graph needs to be Resolve()d after this call. diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 646f33ed952a4..d73d551920d47 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -1,5 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + +// This header is to expose a context for cuda custom ops. +// By the context, a custom cuda operator could fetch existing resources, +// such as cuda stream and cudnn handle, for reusing. + +// For concrete usage, pls find page here: +// https://onnxruntime.ai/docs/reference/operators/add-custom-op.html#custom-ops-for-cuda-and-rocm + #pragma once #define ORT_CUDA_CTX @@ -21,7 +29,7 @@ struct CudaContext : public CustomOpContext { cublasHandle_t cublas_handle = {}; OrtAllocator* deferred_cpu_allocator = {}; - void Init(const OrtKernelContext& kernel_ctx) override { + void Init(const OrtKernelContext& kernel_ctx) { const auto& ort_api = Ort::GetApi(); void* resource = {}; OrtStatus* status = nullptr; diff --git a/include/onnxruntime/core/providers/custom_op_context.h b/include/onnxruntime/core/providers/custom_op_context.h index 547f9a90aff85..8f3d2476d4fdb 100644 --- a/include/onnxruntime/core/providers/custom_op_context.h +++ b/include/onnxruntime/core/providers/custom_op_context.h @@ -3,11 +3,8 @@ #pragma once -#include - // CustomOpContext defines an interface allowing a custom op to access ep-specific resources. struct CustomOpContext { CustomOpContext() = default; virtual ~CustomOpContext(){}; - virtual void Init(const OrtKernelContext&){}; }; \ No newline at end of file diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index cf3ddc3f125f9..7d7f05193f486 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -37,9 +37,13 @@ enum OrtDmlPerformancePreference { }; enum OrtDmlDeviceFilter : uint32_t { +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION Any = 0xffffffff, Gpu = 1 << 0, Npu = 1 << 1, +#else + Gpu = 1 << 0, +#endif }; inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; } diff --git a/include/onnxruntime/core/providers/rocm/rocm_context.h b/include/onnxruntime/core/providers/rocm/rocm_context.h index ff62094f3f439..5f04289a8c6e0 100644 --- a/include/onnxruntime/core/providers/rocm/rocm_context.h +++ b/include/onnxruntime/core/providers/rocm/rocm_context.h @@ -18,7 +18,7 @@ struct RocmContext : public CustomOpContext { miopenHandle_t miopen_handle = {}; rocblas_handle rblas_handle = {}; - void Init(const OrtKernelContext& kernel_ctx) override { + void Init(const OrtKernelContext& kernel_ctx) { const auto& ort_api = Ort::GetApi(); void* resource = {}; OrtStatus* status = nullptr; diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index 443710884743a..0c0af16d4e20c 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -399,6 +399,15 @@ struct TensorArray : public ArgBase { using Variadic = TensorArray; +/* +Note: +OrtLiteCustomOp inherits from OrtCustomOp to bridge tween a custom func/struct and ort core. +The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so: +1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierachy. +2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp, + hence memory could still be recycled properly. +Further, OrtCustomOp is a c struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety. +*/ struct OrtLiteCustomOp : public OrtCustomOp { using ConstOptionalFloatTensor = std::optional&>; using OptionalFloatTensor = std::optional>; @@ -774,10 +783,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtLiteCustomOp(const char* op_name, const char* execution_provider, - int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), - execution_provider_(execution_provider), - start_ver_(start_ver), - end_ver_(end_ver) { + ShapeInferFn shape_infer_fn, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + shape_infer_fn_(shape_infer_fn), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -858,8 +870,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + ShapeInferFn shape_infer_fn_ = {}; + int start_ver_ = 1; int end_ver_ = MAX_CUSTOM_OP_END_VER; + + void* compute_fn_ = {}; + void* compute_fn_return_status_ = {}; }; //////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -891,9 +908,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFn compute_fn, ShapeInferFn shape_infer_fn = {}, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), - compute_fn_(compute_fn), - shape_infer_fn_(shape_infer_fn) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_ = reinterpret_cast(compute_fn); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { @@ -905,7 +921,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_ = static_cast(this_)->compute_fn_; + auto me = static_cast(this_); + kernel->compute_fn_ = reinterpret_cast(me->compute_fn_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -931,9 +948,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFnReturnStatus compute_fn_return_status, ShapeInferFn shape_infer_fn = {}, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), - compute_fn_return_status_(compute_fn_return_status), - shape_infer_fn_(shape_infer_fn) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_return_status_ = reinterpret_cast(compute_fn_return_status); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { @@ -945,7 +961,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_return_status_ = static_cast(this_)->compute_fn_return_status_; + auto me = static_cast(this_); + kernel->compute_fn_return_status_ = reinterpret_cast(me->compute_fn_return_status_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -965,10 +982,6 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { }; } } - - ComputeFn compute_fn_ = {}; - ComputeFnReturnStatus compute_fn_return_status_ = {}; - ShapeInferFn shape_infer_fn_ = {}; }; // struct OrtLiteCustomFunc /////////////////////////// OrtLiteCustomStruct /////////////////////////// @@ -1007,7 +1020,7 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { OrtLiteCustomStruct(const char* op_name, const char* execution_provider, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) { SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { diff --git a/js/.eslintrc.js b/js/.eslintrc.js index 56d2d45ee5d4f..0bf47c5264f61 100644 --- a/js/.eslintrc.js +++ b/js/.eslintrc.js @@ -5,10 +5,18 @@ module.exports = { root: true, - ignorePatterns: ['**/*.js', 'ort-schema/', 'common/test/type-tests/', 'test/data/', 'node_modules/', 'dist/'], + ignorePatterns: [ + '**/*.js', + 'node_modules/', + 'ort-schema/', + 'common/test/type-tests/', + 'web/types.d.ts', + 'test/data/', + 'dist/', + ], env: { 'es6': true }, parser: '@typescript-eslint/parser', - parserOptions: { 'project': 'tsconfig.json', 'sourceType': 'module' }, + parserOptions: { 'project': true, 'sourceType': 'module' }, plugins: ['@typescript-eslint', 'prefer-arrow', 'header', 'import', 'unicorn', 'jsdoc'], rules: { 'unicorn/filename-case': 'error', @@ -145,7 +153,54 @@ module.exports = { } }, { files: ['web/lib/**/*.ts'], rules: { - 'no-underscore-dangle': 'off', + 'no-underscore-dangle': ['error', { + 'allow': [ + '_free', + '_malloc', + '_JsepGetNodeName', + '_JsepOutput', + '_OrtAddFreeDimensionOverride', + '_OrtAddRunConfigEntry', + '_OrtAddSessionConfigEntry', + '_OrtAppendExecutionProvider', + '_OrtBindInput', + '_OrtBindOutput', + '_OrtClearBoundOutputs', + '_OrtCreateBinding', + '_OrtCreateRunOptions', + '_OrtCreateSession', + '_OrtCreateSessionOptions', + '_OrtCreateTensor', + '_OrtEndProfiling', + '_OrtFree', + '_OrtGetInputName', + '_OrtGetInputOutputCount', + '_OrtGetLastError', + '_OrtGetOutputName', + '_OrtGetTensorData', + '_OrtInit', + '_OrtReleaseBinding', + '_OrtReleaseRunOptions', + '_OrtReleaseSession', + '_OrtReleaseSessionOptions', + '_OrtReleaseTensor', + '_OrtRun', + '_OrtRunWithBinding', + '_OrtTrainingCopyParametersFromBuffer', + '_OrtTrainingCopyParametersToBuffer', + '_OrtTrainingCreateSession', + '_OrtTrainingEvalStep', + '_OrtTrainingGetModelInputOutputCount', + '_OrtTrainingGetModelInputOutputName', + '_OrtTrainingGetParametersSize', + '_OrtTrainingLazyResetGrad', + '_OrtTrainingLoadCheckpoint', + '_OrtTrainingOptimizerStep', + '_OrtTrainingReleaseCheckpoint', + '_OrtTrainingReleaseSession', + '_OrtTrainingRunTrainStep' + ] + }] } }, { files: ['web/lib/onnxjs/**/*.ts'], rules: { @@ -158,6 +213,7 @@ module.exports = { 'import/no-internal-modules': 'off', 'prefer-arrow/prefer-arrow-functions': 'off', 'no-param-reassign': 'off', + 'no-underscore-dangle': 'off', 'guard-for-in': 'off' } }, { diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 8c1e69a68ca7e..c7760692eed00 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -241,6 +241,7 @@ export declare namespace InferenceSession { export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; deviceType?: 'cpu'|'gpu'; + numThreads?: number; powerPreference?: 'default'|'low-power'|'high-performance'; } export interface CoreMLExecutionProviderOption extends ExecutionProviderOption { diff --git a/js/common/tsconfig.json b/js/common/tsconfig.json index 4751cd7564f5d..d7bb3a593f66f 100644 --- a/js/common/tsconfig.json +++ b/js/common/tsconfig.json @@ -4,8 +4,7 @@ "outDir": "./dist/esm", "declaration": true, "declarationMap": true, - "esModuleInterop": false, - "noUnusedParameters": true + "esModuleInterop": false }, "include": ["lib"] } diff --git a/js/node/package-lock.json b/js/node/package-lock.json index ce390aa88c0aa..c1cf8af4bb80e 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -22,7 +22,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" } }, "../common": { @@ -97,12 +97,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "node_modules/@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "node_modules/@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -165,9 +159,9 @@ "dev": true }, "node_modules/axios": { - "version": "1.3.4", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.3.4.tgz", - "integrity": "sha512-toYm+Bsyl6VC5wSkfkbbNB6ROv7KY93PEBBL6xyDczaIHasAiv4wPqQ/c4RjoQzipxRD2W5g21cOqQulZ7rHwQ==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz", + "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==", "dev": true, "dependencies": { "follow-redirects": "^1.15.0", @@ -528,9 +522,9 @@ "dev": true }, "node_modules/long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "node_modules/lru-cache": { @@ -663,15 +657,6 @@ "node": "^12.13.0 || ^14.15.0 || >=16.0.0" } }, - "node_modules/onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "dependencies": { - "protobufjs": "^6.11.2" - } - }, "node_modules/onnxruntime-common": { "resolved": "../common", "link": true @@ -690,9 +675,9 @@ } }, "node_modules/protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "hasInstallScript": true, "dependencies": { @@ -706,13 +691,11 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" }, - "bin": { - "pbjs": "bin/pbjs", - "pbts": "bin/pbts" + "engines": { + "node": ">=12.0.0" } }, "node_modules/proxy-from-env": { @@ -789,9 +772,9 @@ ] }, "node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "dependencies": { "lru-cache": "^6.0.0" @@ -1070,12 +1053,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -1126,9 +1103,9 @@ "dev": true }, "axios": { - "version": "1.3.4", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.3.4.tgz", - "integrity": "sha512-toYm+Bsyl6VC5wSkfkbbNB6ROv7KY93PEBBL6xyDczaIHasAiv4wPqQ/c4RjoQzipxRD2W5g21cOqQulZ7rHwQ==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz", + "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==", "dev": true, "requires": { "follow-redirects": "^1.15.0", @@ -1413,9 +1390,9 @@ "dev": true }, "long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "lru-cache": { @@ -1523,15 +1500,6 @@ "set-blocking": "^2.0.0" } }, - "onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "requires": { - "protobufjs": "^6.11.2" - } - }, "onnxruntime-common": { "version": "file:../common", "requires": { @@ -1549,9 +1517,9 @@ } }, "protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "requires": { "@protobufjs/aspromise": "^1.1.2", @@ -1564,9 +1532,8 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" } }, "proxy-from-env": { @@ -1619,9 +1586,9 @@ "dev": true }, "semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "requires": { "lru-cache": "^6.0.0" diff --git a/js/node/package.json b/js/node/package.json index 0f8f0e9d2260c..8e591d8f46b9d 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -19,6 +19,7 @@ }, "scripts": { "buildr": "tsc && node ./script/build --config=RelWithDebInfo", + "preprepare": "node -e \"require('node:fs').copyFileSync('./node_modules/long/index.d.ts', './node_modules/long/umd/index.d.ts')\"", "prepare": "tsc --build script test .", "rebuild": "tsc && node ./script/build --rebuild", "rebuildd": "tsc && node ./script/build --rebuild --config=Debug", @@ -39,7 +40,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" }, "main": "dist/index.js", "os": [ diff --git a/js/node/test/ort-schema/protobuf/.gitignore b/js/node/test/ort-schema/protobuf/.gitignore new file mode 100644 index 0000000000000..092bb6c1c9fb4 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/.gitignore @@ -0,0 +1,2 @@ +!onnx.js +!onnx.d.ts diff --git a/js/node/test/ort-schema/protobuf/README.md b/js/node/test/ort-schema/protobuf/README.md new file mode 100644 index 0000000000000..f5f52c602f1ad --- /dev/null +++ b/js/node/test/ort-schema/protobuf/README.md @@ -0,0 +1,21 @@ +# ONNX protobuf + +This directory contains generated protobuf definition for onnx: + +- onnx.js +- onnx.d.ts + +These files are generated from [a fork of onnx-proto](https://github.com/fs-eire/onnx-proto/tree/update-v9). + +The ONNX protobuf uses protobufjs@7.2.4, which depends on long@5.2.3, the version contains 2 bugs: + +- type export does not work with commonjs. described in https://github.com/dcodeIO/long.js/pull/124. added a "postinstall" script to fix. +- in the generated typescript declaration file 'onnx.d.ts', the following line: + ```ts + import Long = require("long"); + ``` + need to be replaced to fix type import error: + ```ts + import Long from "long"; + ``` + this replacement is done and code format is also applied to file 'onnx.d.ts'. diff --git a/js/node/test/ort-schema/protobuf/onnx.d.ts b/js/node/test/ort-schema/protobuf/onnx.d.ts new file mode 100644 index 0000000000000..c60264dca2a8d --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.d.ts @@ -0,0 +1,2627 @@ +import Long from 'long'; +import * as $protobuf from 'protobufjs'; + +/** Namespace onnx. */ +export namespace onnx { + + /** Version enum. */ + enum Version { + _START_VERSION = 0, + IR_VERSION_2017_10_10 = 1, + IR_VERSION_2017_10_30 = 2, + IR_VERSION_2017_11_3 = 3, + IR_VERSION_2019_1_22 = 4, + IR_VERSION_2019_3_18 = 5, + IR_VERSION_2019_9_19 = 6, + IR_VERSION_2020_5_8 = 7, + IR_VERSION_2021_7_30 = 8, + IR_VERSION = 9 + } + + /** Properties of an AttributeProto. */ + interface IAttributeProto { + /** AttributeProto name */ + name?: (string|null); + + /** AttributeProto refAttrName */ + refAttrName?: (string|null); + + /** AttributeProto docString */ + docString?: (string|null); + + /** AttributeProto type */ + type?: (onnx.AttributeProto.AttributeType|null); + + /** AttributeProto f */ + f?: (number|null); + + /** AttributeProto i */ + i?: (number|Long|null); + + /** AttributeProto s */ + s?: (Uint8Array|null); + + /** AttributeProto t */ + t?: (onnx.ITensorProto|null); + + /** AttributeProto g */ + g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor */ + sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp */ + tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats */ + floats?: (number[]|null); + + /** AttributeProto ints */ + ints?: ((number | Long)[]|null); + + /** AttributeProto strings */ + strings?: (Uint8Array[]|null); + + /** AttributeProto tensors */ + tensors?: (onnx.ITensorProto[]|null); + + /** AttributeProto graphs */ + graphs?: (onnx.IGraphProto[]|null); + + /** AttributeProto sparseTensors */ + sparseTensors?: (onnx.ISparseTensorProto[]|null); + + /** AttributeProto typeProtos */ + typeProtos?: (onnx.ITypeProto[]|null); + } + + /** Represents an AttributeProto. */ + class AttributeProto implements IAttributeProto { + /** + * Constructs a new AttributeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IAttributeProto); + + /** AttributeProto name. */ + public name: string; + + /** AttributeProto refAttrName. */ + public refAttrName: string; + + /** AttributeProto docString. */ + public docString: string; + + /** AttributeProto type. */ + public type: onnx.AttributeProto.AttributeType; + + /** AttributeProto f. */ + public f: number; + + /** AttributeProto i. */ + public i: (number|Long); + + /** AttributeProto s. */ + public s: Uint8Array; + + /** AttributeProto t. */ + public t?: (onnx.ITensorProto|null); + + /** AttributeProto g. */ + public g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor. */ + public sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp. */ + public tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats. */ + public floats: number[]; + + /** AttributeProto ints. */ + public ints: (number|Long)[]; + + /** AttributeProto strings. */ + public strings: Uint8Array[]; + + /** AttributeProto tensors. */ + public tensors: onnx.ITensorProto[]; + + /** AttributeProto graphs. */ + public graphs: onnx.IGraphProto[]; + + /** AttributeProto sparseTensors. */ + public sparseTensors: onnx.ISparseTensorProto[]; + + /** AttributeProto typeProtos. */ + public typeProtos: onnx.ITypeProto[]; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns AttributeProto instance + */ + public static create(properties?: onnx.IAttributeProto): onnx.AttributeProto; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} + * messages. + * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link + * onnx.AttributeProto.verify|verify} messages. + * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.AttributeProto; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.AttributeProto; + + /** + * Verifies an AttributeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns AttributeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.AttributeProto; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @param message AttributeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.AttributeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this AttributeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for AttributeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace AttributeProto { + + /** AttributeType enum. */ + enum AttributeType { + UNDEFINED = 0, + FLOAT = 1, + INT = 2, + STRING = 3, + TENSOR = 4, + GRAPH = 5, + SPARSE_TENSOR = 11, + TYPE_PROTO = 13, + FLOATS = 6, + INTS = 7, + STRINGS = 8, + TENSORS = 9, + GRAPHS = 10, + SPARSE_TENSORS = 12, + TYPE_PROTOS = 14 + } + } + + /** Properties of a ValueInfoProto. */ + interface IValueInfoProto { + /** ValueInfoProto name */ + name?: (string|null); + + /** ValueInfoProto type */ + type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString */ + docString?: (string|null); + } + + /** Represents a ValueInfoProto. */ + class ValueInfoProto implements IValueInfoProto { + /** + * Constructs a new ValueInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IValueInfoProto); + + /** ValueInfoProto name. */ + public name: string; + + /** ValueInfoProto type. */ + public type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString. */ + public docString: string; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @param [properties] Properties to set + * @returns ValueInfoProto instance + */ + public static create(properties?: onnx.IValueInfoProto): onnx.ValueInfoProto; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} + * messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link + * onnx.ValueInfoProto.verify|verify} messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ValueInfoProto; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ValueInfoProto; + + /** + * Verifies a ValueInfoProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ValueInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ValueInfoProto; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. + * @param message ValueInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ValueInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ValueInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ValueInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a NodeProto. */ + interface INodeProto { + /** NodeProto input */ + input?: (string[]|null); + + /** NodeProto output */ + output?: (string[]|null); + + /** NodeProto name */ + name?: (string|null); + + /** NodeProto opType */ + opType?: (string|null); + + /** NodeProto domain */ + domain?: (string|null); + + /** NodeProto attribute */ + attribute?: (onnx.IAttributeProto[]|null); + + /** NodeProto docString */ + docString?: (string|null); + } + + /** Represents a NodeProto. */ + class NodeProto implements INodeProto { + /** + * Constructs a new NodeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.INodeProto); + + /** NodeProto input. */ + public input: string[]; + + /** NodeProto output. */ + public output: string[]; + + /** NodeProto name. */ + public name: string; + + /** NodeProto opType. */ + public opType: string; + + /** NodeProto domain. */ + public domain: string; + + /** NodeProto attribute. */ + public attribute: onnx.IAttributeProto[]; + + /** NodeProto docString. */ + public docString: string; + + /** + * Creates a new NodeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns NodeProto instance + */ + public static create(properties?: onnx.INodeProto): onnx.NodeProto; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link + * onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.NodeProto; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.NodeProto; + + /** + * Verifies a NodeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns NodeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.NodeProto; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @param message NodeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.NodeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this NodeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for NodeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TrainingInfoProto. */ + interface ITrainingInfoProto { + /** TrainingInfoProto initialization */ + initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm */ + algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding */ + initializationBinding?: (onnx.IStringStringEntryProto[]|null); + + /** TrainingInfoProto updateBinding */ + updateBinding?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TrainingInfoProto. */ + class TrainingInfoProto implements ITrainingInfoProto { + /** + * Constructs a new TrainingInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITrainingInfoProto); + + /** TrainingInfoProto initialization. */ + public initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm. */ + public algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding. */ + public initializationBinding: onnx.IStringStringEntryProto[]; + + /** TrainingInfoProto updateBinding. */ + public updateBinding: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TrainingInfoProto instance + */ + public static create(properties?: onnx.ITrainingInfoProto): onnx.TrainingInfoProto; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} + * messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link + * onnx.TrainingInfoProto.verify|verify} messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TrainingInfoProto; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TrainingInfoProto; + + /** + * Verifies a TrainingInfoProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TrainingInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TrainingInfoProto; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @param message TrainingInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TrainingInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TrainingInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TrainingInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a ModelProto. */ + interface IModelProto { + /** ModelProto irVersion */ + irVersion?: (number|Long|null); + + /** ModelProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** ModelProto producerName */ + producerName?: (string|null); + + /** ModelProto producerVersion */ + producerVersion?: (string|null); + + /** ModelProto domain */ + domain?: (string|null); + + /** ModelProto modelVersion */ + modelVersion?: (number|Long|null); + + /** ModelProto docString */ + docString?: (string|null); + + /** ModelProto graph */ + graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps */ + metadataProps?: (onnx.IStringStringEntryProto[]|null); + + /** ModelProto trainingInfo */ + trainingInfo?: (onnx.ITrainingInfoProto[]|null); + + /** ModelProto functions */ + functions?: (onnx.IFunctionProto[]|null); + } + + /** Represents a ModelProto. */ + class ModelProto implements IModelProto { + /** + * Constructs a new ModelProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IModelProto); + + /** ModelProto irVersion. */ + public irVersion: (number|Long); + + /** ModelProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** ModelProto producerName. */ + public producerName: string; + + /** ModelProto producerVersion. */ + public producerVersion: string; + + /** ModelProto domain. */ + public domain: string; + + /** ModelProto modelVersion. */ + public modelVersion: (number|Long); + + /** ModelProto docString. */ + public docString: string; + + /** ModelProto graph. */ + public graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps. */ + public metadataProps: onnx.IStringStringEntryProto[]; + + /** ModelProto trainingInfo. */ + public trainingInfo: onnx.ITrainingInfoProto[]; + + /** ModelProto functions. */ + public functions: onnx.IFunctionProto[]; + + /** + * Creates a new ModelProto instance using the specified properties. + * @param [properties] Properties to set + * @returns ModelProto instance + */ + public static create(properties?: onnx.IModelProto): onnx.ModelProto; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link + * onnx.ModelProto.verify|verify} messages. + * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ModelProto; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ModelProto; + + /** + * Verifies a ModelProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ModelProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ModelProto; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @param message ModelProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ModelProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ModelProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ModelProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a StringStringEntryProto. */ + interface IStringStringEntryProto { + /** StringStringEntryProto key */ + key?: (string|null); + + /** StringStringEntryProto value */ + value?: (string|null); + } + + /** Represents a StringStringEntryProto. */ + class StringStringEntryProto implements IStringStringEntryProto { + /** + * Constructs a new StringStringEntryProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IStringStringEntryProto); + + /** StringStringEntryProto key. */ + public key: string; + + /** StringStringEntryProto value. */ + public value: string; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. + * @param [properties] Properties to set + * @returns StringStringEntryProto instance + */ + public static create(properties?: onnx.IStringStringEntryProto): onnx.StringStringEntryProto; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.StringStringEntryProto; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.StringStringEntryProto; + + /** + * Verifies a StringStringEntryProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns StringStringEntryProto + */ + public static fromObject(object: {[k: string]: any}): onnx.StringStringEntryProto; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @param message StringStringEntryProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.StringStringEntryProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this StringStringEntryProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for StringStringEntryProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorAnnotation. */ + interface ITensorAnnotation { + /** TensorAnnotation tensorName */ + tensorName?: (string|null); + + /** TensorAnnotation quantParameterTensorNames */ + quantParameterTensorNames?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TensorAnnotation. */ + class TensorAnnotation implements ITensorAnnotation { + /** + * Constructs a new TensorAnnotation. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorAnnotation); + + /** TensorAnnotation tensorName. */ + public tensorName: string; + + /** TensorAnnotation quantParameterTensorNames. */ + public quantParameterTensorNames: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorAnnotation instance + */ + public static create(properties?: onnx.ITensorAnnotation): onnx.TensorAnnotation; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} + * messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link + * onnx.TensorAnnotation.verify|verify} messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorAnnotation; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorAnnotation; + + /** + * Verifies a TensorAnnotation message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorAnnotation + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorAnnotation; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @param message TensorAnnotation + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorAnnotation, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorAnnotation to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorAnnotation + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a GraphProto. */ + interface IGraphProto { + /** GraphProto node */ + node?: (onnx.INodeProto[]|null); + + /** GraphProto name */ + name?: (string|null); + + /** GraphProto initializer */ + initializer?: (onnx.ITensorProto[]|null); + + /** GraphProto sparseInitializer */ + sparseInitializer?: (onnx.ISparseTensorProto[]|null); + + /** GraphProto docString */ + docString?: (string|null); + + /** GraphProto input */ + input?: (onnx.IValueInfoProto[]|null); + + /** GraphProto output */ + output?: (onnx.IValueInfoProto[]|null); + + /** GraphProto valueInfo */ + valueInfo?: (onnx.IValueInfoProto[]|null); + + /** GraphProto quantizationAnnotation */ + quantizationAnnotation?: (onnx.ITensorAnnotation[]|null); + } + + /** Represents a GraphProto. */ + class GraphProto implements IGraphProto { + /** + * Constructs a new GraphProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IGraphProto); + + /** GraphProto node. */ + public node: onnx.INodeProto[]; + + /** GraphProto name. */ + public name: string; + + /** GraphProto initializer. */ + public initializer: onnx.ITensorProto[]; + + /** GraphProto sparseInitializer. */ + public sparseInitializer: onnx.ISparseTensorProto[]; + + /** GraphProto docString. */ + public docString: string; + + /** GraphProto input. */ + public input: onnx.IValueInfoProto[]; + + /** GraphProto output. */ + public output: onnx.IValueInfoProto[]; + + /** GraphProto valueInfo. */ + public valueInfo: onnx.IValueInfoProto[]; + + /** GraphProto quantizationAnnotation. */ + public quantizationAnnotation: onnx.ITensorAnnotation[]; + + /** + * Creates a new GraphProto instance using the specified properties. + * @param [properties] Properties to set + * @returns GraphProto instance + */ + public static create(properties?: onnx.IGraphProto): onnx.GraphProto; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link + * onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.GraphProto; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.GraphProto; + + /** + * Verifies a GraphProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns GraphProto + */ + public static fromObject(object: {[k: string]: any}): onnx.GraphProto; + + /** + * Creates a plain object from a GraphProto message. Also converts values to other types if specified. + * @param message GraphProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.GraphProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this GraphProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for GraphProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorProto. */ + interface ITensorProto { + /** TensorProto dims */ + dims?: ((number | Long)[]|null); + + /** TensorProto dataType */ + dataType?: (number|null); + + /** TensorProto segment */ + segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData */ + floatData?: (number[]|null); + + /** TensorProto int32Data */ + int32Data?: (number[]|null); + + /** TensorProto stringData */ + stringData?: (Uint8Array[]|null); + + /** TensorProto int64Data */ + int64Data?: ((number | Long)[]|null); + + /** TensorProto name */ + name?: (string|null); + + /** TensorProto docString */ + docString?: (string|null); + + /** TensorProto rawData */ + rawData?: (Uint8Array|null); + + /** TensorProto externalData */ + externalData?: (onnx.IStringStringEntryProto[]|null); + + /** TensorProto dataLocation */ + dataLocation?: (onnx.TensorProto.DataLocation|null); + + /** TensorProto doubleData */ + doubleData?: (number[]|null); + + /** TensorProto uint64Data */ + uint64Data?: ((number | Long)[]|null); + } + + /** Represents a TensorProto. */ + class TensorProto implements ITensorProto { + /** + * Constructs a new TensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorProto); + + /** TensorProto dims. */ + public dims: (number|Long)[]; + + /** TensorProto dataType. */ + public dataType: number; + + /** TensorProto segment. */ + public segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData. */ + public floatData: number[]; + + /** TensorProto int32Data. */ + public int32Data: number[]; + + /** TensorProto stringData. */ + public stringData: Uint8Array[]; + + /** TensorProto int64Data. */ + public int64Data: (number|Long)[]; + + /** TensorProto name. */ + public name: string; + + /** TensorProto docString. */ + public docString: string; + + /** TensorProto rawData. */ + public rawData: Uint8Array; + + /** TensorProto externalData. */ + public externalData: onnx.IStringStringEntryProto[]; + + /** TensorProto dataLocation. */ + public dataLocation: onnx.TensorProto.DataLocation; + + /** TensorProto doubleData. */ + public doubleData: number[]; + + /** TensorProto uint64Data. */ + public uint64Data: (number|Long)[]; + + /** + * Creates a new TensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorProto instance + */ + public static create(properties?: onnx.ITensorProto): onnx.TensorProto; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link + * onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto; + + /** + * Verifies a TensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @param message TensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorProto { + + /** DataType enum. */ + enum DataType { + UNDEFINED = 0, + FLOAT = 1, + UINT8 = 2, + INT8 = 3, + UINT16 = 4, + INT16 = 5, + INT32 = 6, + INT64 = 7, + STRING = 8, + BOOL = 9, + FLOAT16 = 10, + DOUBLE = 11, + UINT32 = 12, + UINT64 = 13, + COMPLEX64 = 14, + COMPLEX128 = 15, + BFLOAT16 = 16, + FLOAT8E4M3FN = 17, + FLOAT8E4M3FNUZ = 18, + FLOAT8E5M2 = 19, + FLOAT8E5M2FNUZ = 20 + } + + /** Properties of a Segment. */ + interface ISegment { + /** Segment begin */ + begin?: (number|Long|null); + + /** Segment end */ + end?: (number|Long|null); + } + + /** Represents a Segment. */ + class Segment implements ISegment { + /** + * Constructs a new Segment. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorProto.ISegment); + + /** Segment begin. */ + public begin: (number|Long); + + /** Segment end. */ + public end: (number|Long); + + /** + * Creates a new Segment instance using the specified properties. + * @param [properties] Properties to set + * @returns Segment instance + */ + public static create(properties?: onnx.TensorProto.ISegment): onnx.TensorProto.Segment; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} + * messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link + * onnx.TensorProto.Segment.verify|verify} messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto.Segment; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto.Segment; + + /** + * Verifies a Segment message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Segment + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto.Segment; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @param message Segment + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto.Segment, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Segment to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Segment + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** DataLocation enum. */ + enum DataLocation { DEFAULT = 0, EXTERNAL = 1 } + } + + /** Properties of a SparseTensorProto. */ + interface ISparseTensorProto { + /** SparseTensorProto values */ + values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices */ + indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims */ + dims?: ((number | Long)[]|null); + } + + /** Represents a SparseTensorProto. */ + class SparseTensorProto implements ISparseTensorProto { + /** + * Constructs a new SparseTensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ISparseTensorProto); + + /** SparseTensorProto values. */ + public values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices. */ + public indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims. */ + public dims: (number|Long)[]; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensorProto instance + */ + public static create(properties?: onnx.ISparseTensorProto): onnx.SparseTensorProto; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} + * messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link + * onnx.SparseTensorProto.verify|verify} messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.SparseTensorProto; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.SparseTensorProto; + + /** + * Verifies a SparseTensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns SparseTensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.SparseTensorProto; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @param message SparseTensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.SparseTensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this SparseTensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorShapeProto. */ + interface ITensorShapeProto { + /** TensorShapeProto dim */ + dim?: (onnx.TensorShapeProto.IDimension[]|null); + } + + /** Represents a TensorShapeProto. */ + class TensorShapeProto implements ITensorShapeProto { + /** + * Constructs a new TensorShapeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorShapeProto); + + /** TensorShapeProto dim. */ + public dim: onnx.TensorShapeProto.IDimension[]; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorShapeProto instance + */ + public static create(properties?: onnx.ITensorShapeProto): onnx.TensorShapeProto; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} + * messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.verify|verify} messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto; + + /** + * Verifies a TensorShapeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorShapeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @param message TensorShapeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorShapeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorShapeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorShapeProto { + + /** Properties of a Dimension. */ + interface IDimension { + /** Dimension dimValue */ + dimValue?: (number|Long|null); + + /** Dimension dimParam */ + dimParam?: (string|null); + + /** Dimension denotation */ + denotation?: (string|null); + } + + /** Represents a Dimension. */ + class Dimension implements IDimension { + /** + * Constructs a new Dimension. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorShapeProto.IDimension); + + /** Dimension dimValue. */ + public dimValue?: (number|Long|null); + + /** Dimension dimParam. */ + public dimParam?: (string|null); + + /** Dimension denotation. */ + public denotation: string; + + /** Dimension value. */ + public value?: ('dimValue'|'dimParam'); + + /** + * Creates a new Dimension instance using the specified properties. + * @param [properties] Properties to set + * @returns Dimension instance + */ + public static create(properties?: onnx.TensorShapeProto.IDimension): onnx.TensorShapeProto.Dimension; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): + $protobuf.Writer; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto.Dimension; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto.Dimension; + + /** + * Verifies a Dimension message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Dimension + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto.Dimension; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @param message Dimension + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto.Dimension, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Dimension to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Dimension + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of a TypeProto. */ + interface ITypeProto { + /** TypeProto tensorType */ + tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType */ + sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType */ + mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType */ + optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType */ + sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation */ + denotation?: (string|null); + } + + /** Represents a TypeProto. */ + class TypeProto implements ITypeProto { + /** + * Constructs a new TypeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITypeProto); + + /** TypeProto tensorType. */ + public tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType. */ + public sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType. */ + public mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType. */ + public optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType. */ + public sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation. */ + public denotation: string; + + /** TypeProto value. */ + public value?: ('tensorType'|'sequenceType'|'mapType'|'optionalType'|'sparseTensorType'); + + /** + * Creates a new TypeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TypeProto instance + */ + public static create(properties?: onnx.ITypeProto): onnx.TypeProto; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link + * onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto; + + /** + * Verifies a TypeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TypeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. + * @param message TypeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TypeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TypeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TypeProto { + + /** Properties of a Tensor. */ + interface ITensor { + /** Tensor elemType */ + elemType?: (number|null); + + /** Tensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a Tensor. */ + class Tensor implements ITensor { + /** + * Constructs a new Tensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ITensor); + + /** Tensor elemType. */ + public elemType: number; + + /** Tensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new Tensor instance using the specified properties. + * @param [properties] Properties to set + * @returns Tensor instance + */ + public static create(properties?: onnx.TypeProto.ITensor): onnx.TypeProto.Tensor; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Tensor; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Tensor; + + /** + * Verifies a Tensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Tensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Tensor; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @param message Tensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Tensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Tensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Tensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Sequence. */ + interface ISequence { + /** Sequence elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents a Sequence. */ + class Sequence implements ISequence { + /** + * Constructs a new Sequence. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISequence); + + /** Sequence elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Sequence instance using the specified properties. + * @param [properties] Properties to set + * @returns Sequence instance + */ + public static create(properties?: onnx.TypeProto.ISequence): onnx.TypeProto.Sequence; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} + * messages. + * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Sequence.verify|verify} messages. + * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Sequence; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Sequence; + + /** + * Verifies a Sequence message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Sequence + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Sequence; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @param message Sequence + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Sequence, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Sequence to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Sequence + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Map. */ + interface IMap { + /** Map keyType */ + keyType?: (number|null); + + /** Map valueType */ + valueType?: (onnx.ITypeProto|null); + } + + /** Represents a Map. */ + class Map implements IMap { + /** + * Constructs a new Map. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IMap); + + /** Map keyType. */ + public keyType: number; + + /** Map valueType. */ + public valueType?: (onnx.ITypeProto|null); + + /** + * Creates a new Map instance using the specified properties. + * @param [properties] Properties to set + * @returns Map instance + */ + public static create(properties?: onnx.TypeProto.IMap): onnx.TypeProto.Map; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Map.verify|verify} messages. + * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Map message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Map; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Map; + + /** + * Verifies a Map message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Map + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Map; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @param message Map + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Map, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this Map to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Map + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of an Optional. */ + interface IOptional { + /** Optional elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents an Optional. */ + class Optional implements IOptional { + /** + * Constructs a new Optional. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IOptional); + + /** Optional elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Optional instance using the specified properties. + * @param [properties] Properties to set + * @returns Optional instance + */ + public static create(properties?: onnx.TypeProto.IOptional): onnx.TypeProto.Optional; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} + * messages. + * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Optional.verify|verify} messages. + * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Optional; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Optional; + + /** + * Verifies an Optional message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Optional + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Optional; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. + * @param message Optional + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Optional, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Optional to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Optional + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a SparseTensor. */ + interface ISparseTensor { + /** SparseTensor elemType */ + elemType?: (number|null); + + /** SparseTensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a SparseTensor. */ + class SparseTensor implements ISparseTensor { + /** + * Constructs a new SparseTensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISparseTensor); + + /** SparseTensor elemType. */ + public elemType: number; + + /** SparseTensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new SparseTensor instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensor instance + */ + public static create(properties?: onnx.TypeProto.ISparseTensor): onnx.TypeProto.SparseTensor; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.SparseTensor; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.SparseTensor; + + /** + * Verifies a SparseTensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns SparseTensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.SparseTensor; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @param message SparseTensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.SparseTensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this SparseTensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of an OperatorSetIdProto. */ + interface IOperatorSetIdProto { + /** OperatorSetIdProto domain */ + domain?: (string|null); + + /** OperatorSetIdProto version */ + version?: (number|Long|null); + } + + /** Represents an OperatorSetIdProto. */ + class OperatorSetIdProto implements IOperatorSetIdProto { + /** + * Constructs a new OperatorSetIdProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IOperatorSetIdProto); + + /** OperatorSetIdProto domain. */ + public domain: string; + + /** OperatorSetIdProto version. */ + public version: (number|Long); + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @param [properties] Properties to set + * @returns OperatorSetIdProto instance + */ + public static create(properties?: onnx.IOperatorSetIdProto): onnx.OperatorSetIdProto; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. + * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. + * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.OperatorSetIdProto; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.OperatorSetIdProto; + + /** + * Verifies an OperatorSetIdProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns OperatorSetIdProto + */ + public static fromObject(object: {[k: string]: any}): onnx.OperatorSetIdProto; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @param message OperatorSetIdProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.OperatorSetIdProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this OperatorSetIdProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for OperatorSetIdProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** OperatorStatus enum. */ + enum OperatorStatus { EXPERIMENTAL = 0, STABLE = 1 } + + /** Properties of a FunctionProto. */ + interface IFunctionProto { + /** FunctionProto name */ + name?: (string|null); + + /** FunctionProto input */ + input?: (string[]|null); + + /** FunctionProto output */ + output?: (string[]|null); + + /** FunctionProto attribute */ + attribute?: (string[]|null); + + /** FunctionProto attributeProto */ + attributeProto?: (onnx.IAttributeProto[]|null); + + /** FunctionProto node */ + node?: (onnx.INodeProto[]|null); + + /** FunctionProto docString */ + docString?: (string|null); + + /** FunctionProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** FunctionProto domain */ + domain?: (string|null); + } + + /** Represents a FunctionProto. */ + class FunctionProto implements IFunctionProto { + /** + * Constructs a new FunctionProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IFunctionProto); + + /** FunctionProto name. */ + public name: string; + + /** FunctionProto input. */ + public input: string[]; + + /** FunctionProto output. */ + public output: string[]; + + /** FunctionProto attribute. */ + public attribute: string[]; + + /** FunctionProto attributeProto. */ + public attributeProto: onnx.IAttributeProto[]; + + /** FunctionProto node. */ + public node: onnx.INodeProto[]; + + /** FunctionProto docString. */ + public docString: string; + + /** FunctionProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** FunctionProto domain. */ + public domain: string; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @param [properties] Properties to set + * @returns FunctionProto instance + */ + public static create(properties?: onnx.IFunctionProto): onnx.FunctionProto; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} + * messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link + * onnx.FunctionProto.verify|verify} messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.FunctionProto; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.FunctionProto; + + /** + * Verifies a FunctionProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns FunctionProto + */ + public static fromObject(object: {[k: string]: any}): onnx.FunctionProto; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. + * @param message FunctionProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.FunctionProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this FunctionProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for FunctionProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } +} diff --git a/js/node/test/ort-schema/protobuf/onnx.js b/js/node/test/ort-schema/protobuf/onnx.js new file mode 100644 index 0000000000000..681855132d4e8 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.js @@ -0,0 +1,7658 @@ +/*eslint-disable block-scoped-var, id-length, no-control-regex, no-magic-numbers, no-prototype-builtins, no-redeclare, no-shadow, no-var, sort-vars*/ +"use strict"; + +var $protobuf = require("protobufjs/minimal"); + +// Common aliases +var $Reader = $protobuf.Reader, $Writer = $protobuf.Writer, $util = $protobuf.util; + +// Exported root namespace +var $root = $protobuf.roots["default"] || ($protobuf.roots["default"] = {}); + +$root.onnx = (function() { + + /** + * Namespace onnx. + * @exports onnx + * @namespace + */ + var onnx = {}; + + /** + * Version enum. + * @name onnx.Version + * @enum {number} + * @property {number} _START_VERSION=0 _START_VERSION value + * @property {number} IR_VERSION_2017_10_10=1 IR_VERSION_2017_10_10 value + * @property {number} IR_VERSION_2017_10_30=2 IR_VERSION_2017_10_30 value + * @property {number} IR_VERSION_2017_11_3=3 IR_VERSION_2017_11_3 value + * @property {number} IR_VERSION_2019_1_22=4 IR_VERSION_2019_1_22 value + * @property {number} IR_VERSION_2019_3_18=5 IR_VERSION_2019_3_18 value + * @property {number} IR_VERSION_2019_9_19=6 IR_VERSION_2019_9_19 value + * @property {number} IR_VERSION_2020_5_8=7 IR_VERSION_2020_5_8 value + * @property {number} IR_VERSION_2021_7_30=8 IR_VERSION_2021_7_30 value + * @property {number} IR_VERSION=9 IR_VERSION value + */ + onnx.Version = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "_START_VERSION"] = 0; + values[valuesById[1] = "IR_VERSION_2017_10_10"] = 1; + values[valuesById[2] = "IR_VERSION_2017_10_30"] = 2; + values[valuesById[3] = "IR_VERSION_2017_11_3"] = 3; + values[valuesById[4] = "IR_VERSION_2019_1_22"] = 4; + values[valuesById[5] = "IR_VERSION_2019_3_18"] = 5; + values[valuesById[6] = "IR_VERSION_2019_9_19"] = 6; + values[valuesById[7] = "IR_VERSION_2020_5_8"] = 7; + values[valuesById[8] = "IR_VERSION_2021_7_30"] = 8; + values[valuesById[9] = "IR_VERSION"] = 9; + return values; + })(); + + onnx.AttributeProto = (function() { + + /** + * Properties of an AttributeProto. + * @memberof onnx + * @interface IAttributeProto + * @property {string|null} [name] AttributeProto name + * @property {string|null} [refAttrName] AttributeProto refAttrName + * @property {string|null} [docString] AttributeProto docString + * @property {onnx.AttributeProto.AttributeType|null} [type] AttributeProto type + * @property {number|null} [f] AttributeProto f + * @property {number|Long|null} [i] AttributeProto i + * @property {Uint8Array|null} [s] AttributeProto s + * @property {onnx.ITensorProto|null} [t] AttributeProto t + * @property {onnx.IGraphProto|null} [g] AttributeProto g + * @property {onnx.ISparseTensorProto|null} [sparseTensor] AttributeProto sparseTensor + * @property {onnx.ITypeProto|null} [tp] AttributeProto tp + * @property {Array.|null} [floats] AttributeProto floats + * @property {Array.|null} [ints] AttributeProto ints + * @property {Array.|null} [strings] AttributeProto strings + * @property {Array.|null} [tensors] AttributeProto tensors + * @property {Array.|null} [graphs] AttributeProto graphs + * @property {Array.|null} [sparseTensors] AttributeProto sparseTensors + * @property {Array.|null} [typeProtos] AttributeProto typeProtos + */ + + /** + * Constructs a new AttributeProto. + * @memberof onnx + * @classdesc Represents an AttributeProto. + * @implements IAttributeProto + * @constructor + * @param {onnx.IAttributeProto=} [properties] Properties to set + */ + function AttributeProto(properties) { + this.floats = []; + this.ints = []; + this.strings = []; + this.tensors = []; + this.graphs = []; + this.sparseTensors = []; + this.typeProtos = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * AttributeProto name. + * @member {string} name + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.name = ""; + + /** + * AttributeProto refAttrName. + * @member {string} refAttrName + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.refAttrName = ""; + + /** + * AttributeProto docString. + * @member {string} docString + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.docString = ""; + + /** + * AttributeProto type. + * @member {onnx.AttributeProto.AttributeType} type + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.type = 0; + + /** + * AttributeProto f. + * @member {number} f + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.f = 0; + + /** + * AttributeProto i. + * @member {number|Long} i + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * AttributeProto s. + * @member {Uint8Array} s + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.s = $util.newBuffer([]); + + /** + * AttributeProto t. + * @member {onnx.ITensorProto|null|undefined} t + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.t = null; + + /** + * AttributeProto g. + * @member {onnx.IGraphProto|null|undefined} g + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.g = null; + + /** + * AttributeProto sparseTensor. + * @member {onnx.ISparseTensorProto|null|undefined} sparseTensor + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensor = null; + + /** + * AttributeProto tp. + * @member {onnx.ITypeProto|null|undefined} tp + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tp = null; + + /** + * AttributeProto floats. + * @member {Array.} floats + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.floats = $util.emptyArray; + + /** + * AttributeProto ints. + * @member {Array.} ints + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.ints = $util.emptyArray; + + /** + * AttributeProto strings. + * @member {Array.} strings + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.strings = $util.emptyArray; + + /** + * AttributeProto tensors. + * @member {Array.} tensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tensors = $util.emptyArray; + + /** + * AttributeProto graphs. + * @member {Array.} graphs + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.graphs = $util.emptyArray; + + /** + * AttributeProto sparseTensors. + * @member {Array.} sparseTensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensors = $util.emptyArray; + + /** + * AttributeProto typeProtos. + * @member {Array.} typeProtos + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.typeProtos = $util.emptyArray; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @function create + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto=} [properties] Properties to set + * @returns {onnx.AttributeProto} AttributeProto instance + */ + AttributeProto.create = function create(properties) { + return new AttributeProto(properties); + }; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. + * @function encode + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.f != null && Object.hasOwnProperty.call(message, "f")) + writer.uint32(/* id 2, wireType 5 =*/21).float(message.f); + if (message.i != null && Object.hasOwnProperty.call(message, "i")) + writer.uint32(/* id 3, wireType 0 =*/24).int64(message.i); + if (message.s != null && Object.hasOwnProperty.call(message, "s")) + writer.uint32(/* id 4, wireType 2 =*/34).bytes(message.s); + if (message.t != null && Object.hasOwnProperty.call(message, "t")) + $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.g != null && Object.hasOwnProperty.call(message, "g")) + $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/50).fork()).ldelim(); + if (message.floats != null && message.floats.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.floats.length; ++i) + writer.float(message.floats[i]); + writer.ldelim(); + } + if (message.ints != null && message.ints.length) { + writer.uint32(/* id 8, wireType 2 =*/66).fork(); + for (var i = 0; i < message.ints.length; ++i) + writer.int64(message.ints[i]); + writer.ldelim(); + } + if (message.strings != null && message.strings.length) + for (var i = 0; i < message.strings.length; ++i) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.strings[i]); + if (message.tensors != null && message.tensors.length) + for (var i = 0; i < message.tensors.length; ++i) + $root.onnx.TensorProto.encode(message.tensors[i], writer.uint32(/* id 10, wireType 2 =*/82).fork()).ldelim(); + if (message.graphs != null && message.graphs.length) + for (var i = 0; i < message.graphs.length; ++i) + $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 13, wireType 2 =*/106).string(message.docString); + if (message.tp != null && Object.hasOwnProperty.call(message, "tp")) + $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.typeProtos != null && message.typeProtos.length) + for (var i = 0; i < message.typeProtos.length; ++i) + $root.onnx.TypeProto.encode(message.typeProtos[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + writer.uint32(/* id 20, wireType 0 =*/160).int32(message.type); + if (message.refAttrName != null && Object.hasOwnProperty.call(message, "refAttrName")) + writer.uint32(/* id 21, wireType 2 =*/170).string(message.refAttrName); + if (message.sparseTensor != null && Object.hasOwnProperty.call(message, "sparseTensor")) + $root.onnx.SparseTensorProto.encode(message.sparseTensor, writer.uint32(/* id 22, wireType 2 =*/178).fork()).ldelim(); + if (message.sparseTensors != null && message.sparseTensors.length) + for (var i = 0; i < message.sparseTensors.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseTensors[i], writer.uint32(/* id 23, wireType 2 =*/186).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.AttributeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 21: { + message.refAttrName = reader.string(); + break; + } + case 13: { + message.docString = reader.string(); + break; + } + case 20: { + message.type = reader.int32(); + break; + } + case 2: { + message.f = reader.float(); + break; + } + case 3: { + message.i = reader.int64(); + break; + } + case 4: { + message.s = reader.bytes(); + break; + } + case 5: { + message.t = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 6: { + message.g = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 22: { + message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); + break; + } + case 14: { + message.tp = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 7: { + if (!(message.floats && message.floats.length)) + message.floats = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floats.push(reader.float()); + } else + message.floats.push(reader.float()); + break; + } + case 8: { + if (!(message.ints && message.ints.length)) + message.ints = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.ints.push(reader.int64()); + } else + message.ints.push(reader.int64()); + break; + } + case 9: { + if (!(message.strings && message.strings.length)) + message.strings = []; + message.strings.push(reader.bytes()); + break; + } + case 10: { + if (!(message.tensors && message.tensors.length)) + message.tensors = []; + message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 11: { + if (!(message.graphs && message.graphs.length)) + message.graphs = []; + message.graphs.push($root.onnx.GraphProto.decode(reader, reader.uint32())); + break; + } + case 23: { + if (!(message.sparseTensors && message.sparseTensors.length)) + message.sparseTensors = []; + message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.typeProtos && message.typeProtos.length)) + message.typeProtos = []; + message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an AttributeProto message. + * @function verify + * @memberof onnx.AttributeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + AttributeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + if (!$util.isString(message.refAttrName)) + return "refAttrName: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.type != null && message.hasOwnProperty("type")) + switch (message.type) { + default: + return "type: enum value expected"; + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 11: + case 13: + case 6: + case 7: + case 8: + case 9: + case 10: + case 12: + case 14: + break; + } + if (message.f != null && message.hasOwnProperty("f")) + if (typeof message.f !== "number") + return "f: number expected"; + if (message.i != null && message.hasOwnProperty("i")) + if (!$util.isInteger(message.i) && !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high))) + return "i: integer|Long expected"; + if (message.s != null && message.hasOwnProperty("s")) + if (!(message.s && typeof message.s.length === "number" || $util.isString(message.s))) + return "s: buffer expected"; + if (message.t != null && message.hasOwnProperty("t")) { + var error = $root.onnx.TensorProto.verify(message.t); + if (error) + return "t." + error; + } + if (message.g != null && message.hasOwnProperty("g")) { + var error = $root.onnx.GraphProto.verify(message.g); + if (error) + return "g." + error; + } + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor); + if (error) + return "sparseTensor." + error; + } + if (message.tp != null && message.hasOwnProperty("tp")) { + var error = $root.onnx.TypeProto.verify(message.tp); + if (error) + return "tp." + error; + } + if (message.floats != null && message.hasOwnProperty("floats")) { + if (!Array.isArray(message.floats)) + return "floats: array expected"; + for (var i = 0; i < message.floats.length; ++i) + if (typeof message.floats[i] !== "number") + return "floats: number[] expected"; + } + if (message.ints != null && message.hasOwnProperty("ints")) { + if (!Array.isArray(message.ints)) + return "ints: array expected"; + for (var i = 0; i < message.ints.length; ++i) + if (!$util.isInteger(message.ints[i]) && !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high))) + return "ints: integer|Long[] expected"; + } + if (message.strings != null && message.hasOwnProperty("strings")) { + if (!Array.isArray(message.strings)) + return "strings: array expected"; + for (var i = 0; i < message.strings.length; ++i) + if (!(message.strings[i] && typeof message.strings[i].length === "number" || $util.isString(message.strings[i]))) + return "strings: buffer[] expected"; + } + if (message.tensors != null && message.hasOwnProperty("tensors")) { + if (!Array.isArray(message.tensors)) + return "tensors: array expected"; + for (var i = 0; i < message.tensors.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.tensors[i]); + if (error) + return "tensors." + error; + } + } + if (message.graphs != null && message.hasOwnProperty("graphs")) { + if (!Array.isArray(message.graphs)) + return "graphs: array expected"; + for (var i = 0; i < message.graphs.length; ++i) { + var error = $root.onnx.GraphProto.verify(message.graphs[i]); + if (error) + return "graphs." + error; + } + } + if (message.sparseTensors != null && message.hasOwnProperty("sparseTensors")) { + if (!Array.isArray(message.sparseTensors)) + return "sparseTensors: array expected"; + for (var i = 0; i < message.sparseTensors.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); + if (error) + return "sparseTensors." + error; + } + } + if (message.typeProtos != null && message.hasOwnProperty("typeProtos")) { + if (!Array.isArray(message.typeProtos)) + return "typeProtos: array expected"; + for (var i = 0; i < message.typeProtos.length; ++i) { + var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); + if (error) + return "typeProtos." + error; + } + } + return null; + }; + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.AttributeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.AttributeProto} AttributeProto + */ + AttributeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.AttributeProto) + return object; + var message = new $root.onnx.AttributeProto(); + if (object.name != null) + message.name = String(object.name); + if (object.refAttrName != null) + message.refAttrName = String(object.refAttrName); + if (object.docString != null) + message.docString = String(object.docString); + switch (object.type) { + default: + if (typeof object.type === "number") { + message.type = object.type; + break; + } + break; + case "UNDEFINED": + case 0: + message.type = 0; + break; + case "FLOAT": + case 1: + message.type = 1; + break; + case "INT": + case 2: + message.type = 2; + break; + case "STRING": + case 3: + message.type = 3; + break; + case "TENSOR": + case 4: + message.type = 4; + break; + case "GRAPH": + case 5: + message.type = 5; + break; + case "SPARSE_TENSOR": + case 11: + message.type = 11; + break; + case "TYPE_PROTO": + case 13: + message.type = 13; + break; + case "FLOATS": + case 6: + message.type = 6; + break; + case "INTS": + case 7: + message.type = 7; + break; + case "STRINGS": + case 8: + message.type = 8; + break; + case "TENSORS": + case 9: + message.type = 9; + break; + case "GRAPHS": + case 10: + message.type = 10; + break; + case "SPARSE_TENSORS": + case 12: + message.type = 12; + break; + case "TYPE_PROTOS": + case 14: + message.type = 14; + break; + } + if (object.f != null) + message.f = Number(object.f); + if (object.i != null) + if ($util.Long) + (message.i = $util.Long.fromValue(object.i)).unsigned = false; + else if (typeof object.i === "string") + message.i = parseInt(object.i, 10); + else if (typeof object.i === "number") + message.i = object.i; + else if (typeof object.i === "object") + message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber(); + if (object.s != null) + if (typeof object.s === "string") + $util.base64.decode(object.s, message.s = $util.newBuffer($util.base64.length(object.s)), 0); + else if (object.s.length >= 0) + message.s = object.s; + if (object.t != null) { + if (typeof object.t !== "object") + throw TypeError(".onnx.AttributeProto.t: object expected"); + message.t = $root.onnx.TensorProto.fromObject(object.t); + } + if (object.g != null) { + if (typeof object.g !== "object") + throw TypeError(".onnx.AttributeProto.g: object expected"); + message.g = $root.onnx.GraphProto.fromObject(object.g); + } + if (object.sparseTensor != null) { + if (typeof object.sparseTensor !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensor: object expected"); + message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor); + } + if (object.tp != null) { + if (typeof object.tp !== "object") + throw TypeError(".onnx.AttributeProto.tp: object expected"); + message.tp = $root.onnx.TypeProto.fromObject(object.tp); + } + if (object.floats) { + if (!Array.isArray(object.floats)) + throw TypeError(".onnx.AttributeProto.floats: array expected"); + message.floats = []; + for (var i = 0; i < object.floats.length; ++i) + message.floats[i] = Number(object.floats[i]); + } + if (object.ints) { + if (!Array.isArray(object.ints)) + throw TypeError(".onnx.AttributeProto.ints: array expected"); + message.ints = []; + for (var i = 0; i < object.ints.length; ++i) + if ($util.Long) + (message.ints[i] = $util.Long.fromValue(object.ints[i])).unsigned = false; + else if (typeof object.ints[i] === "string") + message.ints[i] = parseInt(object.ints[i], 10); + else if (typeof object.ints[i] === "number") + message.ints[i] = object.ints[i]; + else if (typeof object.ints[i] === "object") + message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber(); + } + if (object.strings) { + if (!Array.isArray(object.strings)) + throw TypeError(".onnx.AttributeProto.strings: array expected"); + message.strings = []; + for (var i = 0; i < object.strings.length; ++i) + if (typeof object.strings[i] === "string") + $util.base64.decode(object.strings[i], message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i])), 0); + else if (object.strings[i].length >= 0) + message.strings[i] = object.strings[i]; + } + if (object.tensors) { + if (!Array.isArray(object.tensors)) + throw TypeError(".onnx.AttributeProto.tensors: array expected"); + message.tensors = []; + for (var i = 0; i < object.tensors.length; ++i) { + if (typeof object.tensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.tensors: object expected"); + message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]); + } + } + if (object.graphs) { + if (!Array.isArray(object.graphs)) + throw TypeError(".onnx.AttributeProto.graphs: array expected"); + message.graphs = []; + for (var i = 0; i < object.graphs.length; ++i) { + if (typeof object.graphs[i] !== "object") + throw TypeError(".onnx.AttributeProto.graphs: object expected"); + message.graphs[i] = $root.onnx.GraphProto.fromObject(object.graphs[i]); + } + } + if (object.sparseTensors) { + if (!Array.isArray(object.sparseTensors)) + throw TypeError(".onnx.AttributeProto.sparseTensors: array expected"); + message.sparseTensors = []; + for (var i = 0; i < object.sparseTensors.length; ++i) { + if (typeof object.sparseTensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensors: object expected"); + message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]); + } + } + if (object.typeProtos) { + if (!Array.isArray(object.typeProtos)) + throw TypeError(".onnx.AttributeProto.typeProtos: array expected"); + message.typeProtos = []; + for (var i = 0; i < object.typeProtos.length; ++i) { + if (typeof object.typeProtos[i] !== "object") + throw TypeError(".onnx.AttributeProto.typeProtos: object expected"); + message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]); + } + } + return message; + }; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.AttributeProto + * @static + * @param {onnx.AttributeProto} message AttributeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + AttributeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.floats = []; + object.ints = []; + object.strings = []; + object.tensors = []; + object.graphs = []; + object.typeProtos = []; + object.sparseTensors = []; + } + if (options.defaults) { + object.name = ""; + object.f = 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.i = options.longs === String ? "0" : 0; + if (options.bytes === String) + object.s = ""; + else { + object.s = []; + if (options.bytes !== Array) + object.s = $util.newBuffer(object.s); + } + object.t = null; + object.g = null; + object.docString = ""; + object.tp = null; + object.type = options.enums === String ? "UNDEFINED" : 0; + object.refAttrName = ""; + object.sparseTensor = null; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.f != null && message.hasOwnProperty("f")) + object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; + if (message.i != null && message.hasOwnProperty("i")) + if (typeof message.i === "number") + object.i = options.longs === String ? String(message.i) : message.i; + else + object.i = options.longs === String ? $util.Long.prototype.toString.call(message.i) : options.longs === Number ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() : message.i; + if (message.s != null && message.hasOwnProperty("s")) + object.s = options.bytes === String ? $util.base64.encode(message.s, 0, message.s.length) : options.bytes === Array ? Array.prototype.slice.call(message.s) : message.s; + if (message.t != null && message.hasOwnProperty("t")) + object.t = $root.onnx.TensorProto.toObject(message.t, options); + if (message.g != null && message.hasOwnProperty("g")) + object.g = $root.onnx.GraphProto.toObject(message.g, options); + if (message.floats && message.floats.length) { + object.floats = []; + for (var j = 0; j < message.floats.length; ++j) + object.floats[j] = options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; + } + if (message.ints && message.ints.length) { + object.ints = []; + for (var j = 0; j < message.ints.length; ++j) + if (typeof message.ints[j] === "number") + object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; + else + object.ints[j] = options.longs === String ? $util.Long.prototype.toString.call(message.ints[j]) : options.longs === Number ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() : message.ints[j]; + } + if (message.strings && message.strings.length) { + object.strings = []; + for (var j = 0; j < message.strings.length; ++j) + object.strings[j] = options.bytes === String ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.strings[j]) : message.strings[j]; + } + if (message.tensors && message.tensors.length) { + object.tensors = []; + for (var j = 0; j < message.tensors.length; ++j) + object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); + } + if (message.graphs && message.graphs.length) { + object.graphs = []; + for (var j = 0; j < message.graphs.length; ++j) + object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.tp != null && message.hasOwnProperty("tp")) + object.tp = $root.onnx.TypeProto.toObject(message.tp, options); + if (message.typeProtos && message.typeProtos.length) { + object.typeProtos = []; + for (var j = 0; j < message.typeProtos.length; ++j) + object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); + } + if (message.type != null && message.hasOwnProperty("type")) + object.type = options.enums === String ? $root.onnx.AttributeProto.AttributeType[message.type] === undefined ? message.type : $root.onnx.AttributeProto.AttributeType[message.type] : message.type; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + object.refAttrName = message.refAttrName; + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) + object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); + if (message.sparseTensors && message.sparseTensors.length) { + object.sparseTensors = []; + for (var j = 0; j < message.sparseTensors.length; ++j) + object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); + } + return object; + }; + + /** + * Converts this AttributeProto to JSON. + * @function toJSON + * @memberof onnx.AttributeProto + * @instance + * @returns {Object.} JSON object + */ + AttributeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for AttributeProto + * @function getTypeUrl + * @memberof onnx.AttributeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.AttributeProto"; + }; + + /** + * AttributeType enum. + * @name onnx.AttributeProto.AttributeType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} INT=2 INT value + * @property {number} STRING=3 STRING value + * @property {number} TENSOR=4 TENSOR value + * @property {number} GRAPH=5 GRAPH value + * @property {number} SPARSE_TENSOR=11 SPARSE_TENSOR value + * @property {number} TYPE_PROTO=13 TYPE_PROTO value + * @property {number} FLOATS=6 FLOATS value + * @property {number} INTS=7 INTS value + * @property {number} STRINGS=8 STRINGS value + * @property {number} TENSORS=9 TENSORS value + * @property {number} GRAPHS=10 GRAPHS value + * @property {number} SPARSE_TENSORS=12 SPARSE_TENSORS value + * @property {number} TYPE_PROTOS=14 TYPE_PROTOS value + */ + AttributeProto.AttributeType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "INT"] = 2; + values[valuesById[3] = "STRING"] = 3; + values[valuesById[4] = "TENSOR"] = 4; + values[valuesById[5] = "GRAPH"] = 5; + values[valuesById[11] = "SPARSE_TENSOR"] = 11; + values[valuesById[13] = "TYPE_PROTO"] = 13; + values[valuesById[6] = "FLOATS"] = 6; + values[valuesById[7] = "INTS"] = 7; + values[valuesById[8] = "STRINGS"] = 8; + values[valuesById[9] = "TENSORS"] = 9; + values[valuesById[10] = "GRAPHS"] = 10; + values[valuesById[12] = "SPARSE_TENSORS"] = 12; + values[valuesById[14] = "TYPE_PROTOS"] = 14; + return values; + })(); + + return AttributeProto; + })(); + + onnx.ValueInfoProto = (function() { + + /** + * Properties of a ValueInfoProto. + * @memberof onnx + * @interface IValueInfoProto + * @property {string|null} [name] ValueInfoProto name + * @property {onnx.ITypeProto|null} [type] ValueInfoProto type + * @property {string|null} [docString] ValueInfoProto docString + */ + + /** + * Constructs a new ValueInfoProto. + * @memberof onnx + * @classdesc Represents a ValueInfoProto. + * @implements IValueInfoProto + * @constructor + * @param {onnx.IValueInfoProto=} [properties] Properties to set + */ + function ValueInfoProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ValueInfoProto name. + * @member {string} name + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.name = ""; + + /** + * ValueInfoProto type. + * @member {onnx.ITypeProto|null|undefined} type + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.type = null; + + /** + * ValueInfoProto docString. + * @member {string} docString + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.docString = ""; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @function create + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto=} [properties] Properties to set + * @returns {onnx.ValueInfoProto} ValueInfoProto instance + */ + ValueInfoProto.create = function create(properties) { + return new ValueInfoProto(properties); + }; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encode + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.docString); + return writer; + }; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ValueInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 2: { + message.type = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 3: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ValueInfoProto message. + * @function verify + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ValueInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.type != null && message.hasOwnProperty("type")) { + var error = $root.onnx.TypeProto.verify(message.type); + if (error) + return "type." + error; + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ValueInfoProto} ValueInfoProto + */ + ValueInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ValueInfoProto) + return object; + var message = new $root.onnx.ValueInfoProto(); + if (object.name != null) + message.name = String(object.name); + if (object.type != null) { + if (typeof object.type !== "object") + throw TypeError(".onnx.ValueInfoProto.type: object expected"); + message.type = $root.onnx.TypeProto.fromObject(object.type); + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.ValueInfoProto} message ValueInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ValueInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.name = ""; + object.type = null; + object.docString = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.type != null && message.hasOwnProperty("type")) + object.type = $root.onnx.TypeProto.toObject(message.type, options); + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + return object; + }; + + /** + * Converts this ValueInfoProto to JSON. + * @function toJSON + * @memberof onnx.ValueInfoProto + * @instance + * @returns {Object.} JSON object + */ + ValueInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ValueInfoProto + * @function getTypeUrl + * @memberof onnx.ValueInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ValueInfoProto"; + }; + + return ValueInfoProto; + })(); + + onnx.NodeProto = (function() { + + /** + * Properties of a NodeProto. + * @memberof onnx + * @interface INodeProto + * @property {Array.|null} [input] NodeProto input + * @property {Array.|null} [output] NodeProto output + * @property {string|null} [name] NodeProto name + * @property {string|null} [opType] NodeProto opType + * @property {string|null} [domain] NodeProto domain + * @property {Array.|null} [attribute] NodeProto attribute + * @property {string|null} [docString] NodeProto docString + */ + + /** + * Constructs a new NodeProto. + * @memberof onnx + * @classdesc Represents a NodeProto. + * @implements INodeProto + * @constructor + * @param {onnx.INodeProto=} [properties] Properties to set + */ + function NodeProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * NodeProto input. + * @member {Array.} input + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.input = $util.emptyArray; + + /** + * NodeProto output. + * @member {Array.} output + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.output = $util.emptyArray; + + /** + * NodeProto name. + * @member {string} name + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.name = ""; + + /** + * NodeProto opType. + * @member {string} opType + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.opType = ""; + + /** + * NodeProto domain. + * @member {string} domain + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.domain = ""; + + /** + * NodeProto attribute. + * @member {Array.} attribute + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.attribute = $util.emptyArray; + + /** + * NodeProto docString. + * @member {string} docString + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.docString = ""; + + /** + * Creates a new NodeProto instance using the specified properties. + * @function create + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto=} [properties] Properties to set + * @returns {onnx.NodeProto} NodeProto instance + */ + NodeProto.create = function create(properties) { + return new NodeProto(properties); + }; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encode + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.output[i]); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.name); + if (message.opType != null && Object.hasOwnProperty.call(message, "opType")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.opType); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + $root.onnx.AttributeProto.encode(message.attribute[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 7, wireType 2 =*/58).string(message.domain); + return writer; + }; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.NodeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 2: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 3: { + message.name = reader.string(); + break; + } + case 4: { + message.opType = reader.string(); + break; + } + case 7: { + message.domain = reader.string(); + break; + } + case 5: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a NodeProto message. + * @function verify + * @memberof onnx.NodeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + NodeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.opType != null && message.hasOwnProperty("opType")) + if (!$util.isString(message.opType)) + return "opType: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attribute[i]); + if (error) + return "attribute." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.NodeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.NodeProto} NodeProto + */ + NodeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.NodeProto) + return object; + var message = new $root.onnx.NodeProto(); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.NodeProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.NodeProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.name != null) + message.name = String(object.name); + if (object.opType != null) + message.opType = String(object.opType); + if (object.domain != null) + message.domain = String(object.domain); + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.NodeProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) { + if (typeof object.attribute[i] !== "object") + throw TypeError(".onnx.NodeProto.attribute: object expected"); + message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.NodeProto + * @static + * @param {onnx.NodeProto} message NodeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + NodeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + } + if (options.defaults) { + object.name = ""; + object.opType = ""; + object.docString = ""; + object.domain = ""; + } + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.opType != null && message.hasOwnProperty("opType")) + object.opType = message.opType; + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + return object; + }; + + /** + * Converts this NodeProto to JSON. + * @function toJSON + * @memberof onnx.NodeProto + * @instance + * @returns {Object.} JSON object + */ + NodeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for NodeProto + * @function getTypeUrl + * @memberof onnx.NodeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.NodeProto"; + }; + + return NodeProto; + })(); + + onnx.TrainingInfoProto = (function() { + + /** + * Properties of a TrainingInfoProto. + * @memberof onnx + * @interface ITrainingInfoProto + * @property {onnx.IGraphProto|null} [initialization] TrainingInfoProto initialization + * @property {onnx.IGraphProto|null} [algorithm] TrainingInfoProto algorithm + * @property {Array.|null} [initializationBinding] TrainingInfoProto initializationBinding + * @property {Array.|null} [updateBinding] TrainingInfoProto updateBinding + */ + + /** + * Constructs a new TrainingInfoProto. + * @memberof onnx + * @classdesc Represents a TrainingInfoProto. + * @implements ITrainingInfoProto + * @constructor + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + */ + function TrainingInfoProto(properties) { + this.initializationBinding = []; + this.updateBinding = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TrainingInfoProto initialization. + * @member {onnx.IGraphProto|null|undefined} initialization + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initialization = null; + + /** + * TrainingInfoProto algorithm. + * @member {onnx.IGraphProto|null|undefined} algorithm + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.algorithm = null; + + /** + * TrainingInfoProto initializationBinding. + * @member {Array.} initializationBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initializationBinding = $util.emptyArray; + + /** + * TrainingInfoProto updateBinding. + * @member {Array.} updateBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.updateBinding = $util.emptyArray; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @function create + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + * @returns {onnx.TrainingInfoProto} TrainingInfoProto instance + */ + TrainingInfoProto.create = function create(properties) { + return new TrainingInfoProto(properties); + }; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encode + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.initialization != null && Object.hasOwnProperty.call(message, "initialization")) + $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.algorithm != null && Object.hasOwnProperty.call(message, "algorithm")) + $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.initializationBinding != null && message.initializationBinding.length) + for (var i = 0; i < message.initializationBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.initializationBinding[i], writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.updateBinding != null && message.updateBinding.length) + for (var i = 0; i < message.updateBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.updateBinding[i], writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TrainingInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.initializationBinding && message.initializationBinding.length)) + message.initializationBinding = []; + message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 4: { + if (!(message.updateBinding && message.updateBinding.length)) + message.updateBinding = []; + message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TrainingInfoProto message. + * @function verify + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TrainingInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.initialization != null && message.hasOwnProperty("initialization")) { + var error = $root.onnx.GraphProto.verify(message.initialization); + if (error) + return "initialization." + error; + } + if (message.algorithm != null && message.hasOwnProperty("algorithm")) { + var error = $root.onnx.GraphProto.verify(message.algorithm); + if (error) + return "algorithm." + error; + } + if (message.initializationBinding != null && message.hasOwnProperty("initializationBinding")) { + if (!Array.isArray(message.initializationBinding)) + return "initializationBinding: array expected"; + for (var i = 0; i < message.initializationBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); + if (error) + return "initializationBinding." + error; + } + } + if (message.updateBinding != null && message.hasOwnProperty("updateBinding")) { + if (!Array.isArray(message.updateBinding)) + return "updateBinding: array expected"; + for (var i = 0; i < message.updateBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); + if (error) + return "updateBinding." + error; + } + } + return null; + }; + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + */ + TrainingInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TrainingInfoProto) + return object; + var message = new $root.onnx.TrainingInfoProto(); + if (object.initialization != null) { + if (typeof object.initialization !== "object") + throw TypeError(".onnx.TrainingInfoProto.initialization: object expected"); + message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); + } + if (object.algorithm != null) { + if (typeof object.algorithm !== "object") + throw TypeError(".onnx.TrainingInfoProto.algorithm: object expected"); + message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); + } + if (object.initializationBinding) { + if (!Array.isArray(object.initializationBinding)) + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: array expected"); + message.initializationBinding = []; + for (var i = 0; i < object.initializationBinding.length; ++i) { + if (typeof object.initializationBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: object expected"); + message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.initializationBinding[i]); + } + } + if (object.updateBinding) { + if (!Array.isArray(object.updateBinding)) + throw TypeError(".onnx.TrainingInfoProto.updateBinding: array expected"); + message.updateBinding = []; + for (var i = 0; i < object.updateBinding.length; ++i) { + if (typeof object.updateBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.updateBinding: object expected"); + message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.TrainingInfoProto} message TrainingInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TrainingInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.initializationBinding = []; + object.updateBinding = []; + } + if (options.defaults) { + object.initialization = null; + object.algorithm = null; + } + if (message.initialization != null && message.hasOwnProperty("initialization")) + object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); + if (message.algorithm != null && message.hasOwnProperty("algorithm")) + object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); + if (message.initializationBinding && message.initializationBinding.length) { + object.initializationBinding = []; + for (var j = 0; j < message.initializationBinding.length; ++j) + object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.initializationBinding[j], options); + } + if (message.updateBinding && message.updateBinding.length) { + object.updateBinding = []; + for (var j = 0; j < message.updateBinding.length; ++j) + object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); + } + return object; + }; + + /** + * Converts this TrainingInfoProto to JSON. + * @function toJSON + * @memberof onnx.TrainingInfoProto + * @instance + * @returns {Object.} JSON object + */ + TrainingInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TrainingInfoProto + * @function getTypeUrl + * @memberof onnx.TrainingInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TrainingInfoProto"; + }; + + return TrainingInfoProto; + })(); + + onnx.ModelProto = (function() { + + /** + * Properties of a ModelProto. + * @memberof onnx + * @interface IModelProto + * @property {number|Long|null} [irVersion] ModelProto irVersion + * @property {Array.|null} [opsetImport] ModelProto opsetImport + * @property {string|null} [producerName] ModelProto producerName + * @property {string|null} [producerVersion] ModelProto producerVersion + * @property {string|null} [domain] ModelProto domain + * @property {number|Long|null} [modelVersion] ModelProto modelVersion + * @property {string|null} [docString] ModelProto docString + * @property {onnx.IGraphProto|null} [graph] ModelProto graph + * @property {Array.|null} [metadataProps] ModelProto metadataProps + * @property {Array.|null} [trainingInfo] ModelProto trainingInfo + * @property {Array.|null} [functions] ModelProto functions + */ + + /** + * Constructs a new ModelProto. + * @memberof onnx + * @classdesc Represents a ModelProto. + * @implements IModelProto + * @constructor + * @param {onnx.IModelProto=} [properties] Properties to set + */ + function ModelProto(properties) { + this.opsetImport = []; + this.metadataProps = []; + this.trainingInfo = []; + this.functions = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ModelProto irVersion. + * @member {number|Long} irVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.opsetImport = $util.emptyArray; + + /** + * ModelProto producerName. + * @member {string} producerName + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerName = ""; + + /** + * ModelProto producerVersion. + * @member {string} producerVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerVersion = ""; + + /** + * ModelProto domain. + * @member {string} domain + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.domain = ""; + + /** + * ModelProto modelVersion. + * @member {number|Long} modelVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto docString. + * @member {string} docString + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.docString = ""; + + /** + * ModelProto graph. + * @member {onnx.IGraphProto|null|undefined} graph + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.graph = null; + + /** + * ModelProto metadataProps. + * @member {Array.} metadataProps + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.metadataProps = $util.emptyArray; + + /** + * ModelProto trainingInfo. + * @member {Array.} trainingInfo + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.trainingInfo = $util.emptyArray; + + /** + * ModelProto functions. + * @member {Array.} functions + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.functions = $util.emptyArray; + + /** + * Creates a new ModelProto instance using the specified properties. + * @function create + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto=} [properties] Properties to set + * @returns {onnx.ModelProto} ModelProto instance + */ + ModelProto.create = function create(properties) { + return new ModelProto(properties); + }; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encode + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.irVersion != null && Object.hasOwnProperty.call(message, "irVersion")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.irVersion); + if (message.producerName != null && Object.hasOwnProperty.call(message, "producerName")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.producerName); + if (message.producerVersion != null && Object.hasOwnProperty.call(message, "producerVersion")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.producerVersion); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.domain); + if (message.modelVersion != null && Object.hasOwnProperty.call(message, "modelVersion")) + writer.uint32(/* id 5, wireType 0 =*/40).int64(message.modelVersion); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.graph != null && Object.hasOwnProperty.call(message, "graph")) + $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.metadataProps != null && message.metadataProps.length) + for (var i = 0; i < message.metadataProps.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.metadataProps[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.trainingInfo != null && message.trainingInfo.length) + for (var i = 0; i < message.trainingInfo.length; ++i) + $root.onnx.TrainingInfoProto.encode(message.trainingInfo[i], writer.uint32(/* id 20, wireType 2 =*/162).fork()).ldelim(); + if (message.functions != null && message.functions.length) + for (var i = 0; i < message.functions.length; ++i) + $root.onnx.FunctionProto.encode(message.functions[i], writer.uint32(/* id 25, wireType 2 =*/202).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ModelProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.irVersion = reader.int64(); + break; + } + case 8: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.producerName = reader.string(); + break; + } + case 3: { + message.producerVersion = reader.string(); + break; + } + case 4: { + message.domain = reader.string(); + break; + } + case 5: { + message.modelVersion = reader.int64(); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + case 7: { + message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 14: { + if (!(message.metadataProps && message.metadataProps.length)) + message.metadataProps = []; + message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 20: { + if (!(message.trainingInfo && message.trainingInfo.length)) + message.trainingInfo = []; + message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); + break; + } + case 25: { + if (!(message.functions && message.functions.length)) + message.functions = []; + message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ModelProto message. + * @function verify + * @memberof onnx.ModelProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ModelProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (!$util.isInteger(message.irVersion) && !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high))) + return "irVersion: integer|Long expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.producerName != null && message.hasOwnProperty("producerName")) + if (!$util.isString(message.producerName)) + return "producerName: string expected"; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + if (!$util.isString(message.producerVersion)) + return "producerVersion: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (!$util.isInteger(message.modelVersion) && !(message.modelVersion && $util.isInteger(message.modelVersion.low) && $util.isInteger(message.modelVersion.high))) + return "modelVersion: integer|Long expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.graph != null && message.hasOwnProperty("graph")) { + var error = $root.onnx.GraphProto.verify(message.graph); + if (error) + return "graph." + error; + } + if (message.metadataProps != null && message.hasOwnProperty("metadataProps")) { + if (!Array.isArray(message.metadataProps)) + return "metadataProps: array expected"; + for (var i = 0; i < message.metadataProps.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); + if (error) + return "metadataProps." + error; + } + } + if (message.trainingInfo != null && message.hasOwnProperty("trainingInfo")) { + if (!Array.isArray(message.trainingInfo)) + return "trainingInfo: array expected"; + for (var i = 0; i < message.trainingInfo.length; ++i) { + var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); + if (error) + return "trainingInfo." + error; + } + } + if (message.functions != null && message.hasOwnProperty("functions")) { + if (!Array.isArray(message.functions)) + return "functions: array expected"; + for (var i = 0; i < message.functions.length; ++i) { + var error = $root.onnx.FunctionProto.verify(message.functions[i]); + if (error) + return "functions." + error; + } + } + return null; + }; + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ModelProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ModelProto} ModelProto + */ + ModelProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ModelProto) + return object; + var message = new $root.onnx.ModelProto(); + if (object.irVersion != null) + if ($util.Long) + (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; + else if (typeof object.irVersion === "string") + message.irVersion = parseInt(object.irVersion, 10); + else if (typeof object.irVersion === "number") + message.irVersion = object.irVersion; + else if (typeof object.irVersion === "object") + message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.ModelProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.ModelProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.producerName != null) + message.producerName = String(object.producerName); + if (object.producerVersion != null) + message.producerVersion = String(object.producerVersion); + if (object.domain != null) + message.domain = String(object.domain); + if (object.modelVersion != null) + if ($util.Long) + (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; + else if (typeof object.modelVersion === "string") + message.modelVersion = parseInt(object.modelVersion, 10); + else if (typeof object.modelVersion === "number") + message.modelVersion = object.modelVersion; + else if (typeof object.modelVersion === "object") + message.modelVersion = new $util.LongBits(object.modelVersion.low >>> 0, object.modelVersion.high >>> 0).toNumber(); + if (object.docString != null) + message.docString = String(object.docString); + if (object.graph != null) { + if (typeof object.graph !== "object") + throw TypeError(".onnx.ModelProto.graph: object expected"); + message.graph = $root.onnx.GraphProto.fromObject(object.graph); + } + if (object.metadataProps) { + if (!Array.isArray(object.metadataProps)) + throw TypeError(".onnx.ModelProto.metadataProps: array expected"); + message.metadataProps = []; + for (var i = 0; i < object.metadataProps.length; ++i) { + if (typeof object.metadataProps[i] !== "object") + throw TypeError(".onnx.ModelProto.metadataProps: object expected"); + message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); + } + } + if (object.trainingInfo) { + if (!Array.isArray(object.trainingInfo)) + throw TypeError(".onnx.ModelProto.trainingInfo: array expected"); + message.trainingInfo = []; + for (var i = 0; i < object.trainingInfo.length; ++i) { + if (typeof object.trainingInfo[i] !== "object") + throw TypeError(".onnx.ModelProto.trainingInfo: object expected"); + message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); + } + } + if (object.functions) { + if (!Array.isArray(object.functions)) + throw TypeError(".onnx.ModelProto.functions: array expected"); + message.functions = []; + for (var i = 0; i < object.functions.length; ++i) { + if (typeof object.functions[i] !== "object") + throw TypeError(".onnx.ModelProto.functions: object expected"); + message.functions[i] = $root.onnx.FunctionProto.fromObject(object.functions[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ModelProto + * @static + * @param {onnx.ModelProto} message ModelProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ModelProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.opsetImport = []; + object.metadataProps = []; + object.trainingInfo = []; + object.functions = []; + } + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.irVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.irVersion = options.longs === String ? "0" : 0; + object.producerName = ""; + object.producerVersion = ""; + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.modelVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.modelVersion = options.longs === String ? "0" : 0; + object.docString = ""; + object.graph = null; + } + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (typeof message.irVersion === "number") + object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; + else + object.irVersion = options.longs === String ? $util.Long.prototype.toString.call(message.irVersion) : options.longs === Number ? new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() : message.irVersion; + if (message.producerName != null && message.hasOwnProperty("producerName")) + object.producerName = message.producerName; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + object.producerVersion = message.producerVersion; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (typeof message.modelVersion === "number") + object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; + else + object.modelVersion = options.longs === String ? $util.Long.prototype.toString.call(message.modelVersion) : options.longs === Number ? new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() : message.modelVersion; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.graph != null && message.hasOwnProperty("graph")) + object.graph = $root.onnx.GraphProto.toObject(message.graph, options); + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.metadataProps && message.metadataProps.length) { + object.metadataProps = []; + for (var j = 0; j < message.metadataProps.length; ++j) + object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); + } + if (message.trainingInfo && message.trainingInfo.length) { + object.trainingInfo = []; + for (var j = 0; j < message.trainingInfo.length; ++j) + object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); + } + if (message.functions && message.functions.length) { + object.functions = []; + for (var j = 0; j < message.functions.length; ++j) + object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + } + return object; + }; + + /** + * Converts this ModelProto to JSON. + * @function toJSON + * @memberof onnx.ModelProto + * @instance + * @returns {Object.} JSON object + */ + ModelProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ModelProto + * @function getTypeUrl + * @memberof onnx.ModelProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ModelProto"; + }; + + return ModelProto; + })(); + + onnx.StringStringEntryProto = (function() { + + /** + * Properties of a StringStringEntryProto. + * @memberof onnx + * @interface IStringStringEntryProto + * @property {string|null} [key] StringStringEntryProto key + * @property {string|null} [value] StringStringEntryProto value + */ + + /** + * Constructs a new StringStringEntryProto. + * @memberof onnx + * @classdesc Represents a StringStringEntryProto. + * @implements IStringStringEntryProto + * @constructor + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + */ + function StringStringEntryProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * StringStringEntryProto key. + * @member {string} key + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.key = ""; + + /** + * StringStringEntryProto value. + * @member {string} value + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.value = ""; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. + * @function create + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + * @returns {onnx.StringStringEntryProto} StringStringEntryProto instance + */ + StringStringEntryProto.create = function create(properties) { + return new StringStringEntryProto(properties); + }; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encode + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.key != null && Object.hasOwnProperty.call(message, "key")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.key); + if (message.value != null && Object.hasOwnProperty.call(message, "value")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.value); + return writer; + }; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.StringStringEntryProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.key = reader.string(); + break; + } + case 2: { + message.value = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a StringStringEntryProto message. + * @function verify + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + StringStringEntryProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.key != null && message.hasOwnProperty("key")) + if (!$util.isString(message.key)) + return "key: string expected"; + if (message.value != null && message.hasOwnProperty("value")) + if (!$util.isString(message.value)) + return "value: string expected"; + return null; + }; + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + */ + StringStringEntryProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.StringStringEntryProto) + return object; + var message = new $root.onnx.StringStringEntryProto(); + if (object.key != null) + message.key = String(object.key); + if (object.value != null) + message.value = String(object.value); + return message; + }; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.StringStringEntryProto} message StringStringEntryProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + StringStringEntryProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.key = ""; + object.value = ""; + } + if (message.key != null && message.hasOwnProperty("key")) + object.key = message.key; + if (message.value != null && message.hasOwnProperty("value")) + object.value = message.value; + return object; + }; + + /** + * Converts this StringStringEntryProto to JSON. + * @function toJSON + * @memberof onnx.StringStringEntryProto + * @instance + * @returns {Object.} JSON object + */ + StringStringEntryProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for StringStringEntryProto + * @function getTypeUrl + * @memberof onnx.StringStringEntryProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.StringStringEntryProto"; + }; + + return StringStringEntryProto; + })(); + + onnx.TensorAnnotation = (function() { + + /** + * Properties of a TensorAnnotation. + * @memberof onnx + * @interface ITensorAnnotation + * @property {string|null} [tensorName] TensorAnnotation tensorName + * @property {Array.|null} [quantParameterTensorNames] TensorAnnotation quantParameterTensorNames + */ + + /** + * Constructs a new TensorAnnotation. + * @memberof onnx + * @classdesc Represents a TensorAnnotation. + * @implements ITensorAnnotation + * @constructor + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + */ + function TensorAnnotation(properties) { + this.quantParameterTensorNames = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorAnnotation tensorName. + * @member {string} tensorName + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.tensorName = ""; + + /** + * TensorAnnotation quantParameterTensorNames. + * @member {Array.} quantParameterTensorNames + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.quantParameterTensorNames = $util.emptyArray; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @function create + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + * @returns {onnx.TensorAnnotation} TensorAnnotation instance + */ + TensorAnnotation.create = function create(properties) { + return new TensorAnnotation(properties); + }; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encode + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorName != null && Object.hasOwnProperty.call(message, "tensorName")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.tensorName); + if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.quantParameterTensorNames[i], writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorAnnotation(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorName = reader.string(); + break; + } + case 2: { + if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) + message.quantParameterTensorNames = []; + message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorAnnotation message. + * @function verify + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorAnnotation.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + if (!$util.isString(message.tensorName)) + return "tensorName: string expected"; + if (message.quantParameterTensorNames != null && message.hasOwnProperty("quantParameterTensorNames")) { + if (!Array.isArray(message.quantParameterTensorNames)) + return "quantParameterTensorNames: array expected"; + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); + if (error) + return "quantParameterTensorNames." + error; + } + } + return null; + }; + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorAnnotation} TensorAnnotation + */ + TensorAnnotation.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorAnnotation) + return object; + var message = new $root.onnx.TensorAnnotation(); + if (object.tensorName != null) + message.tensorName = String(object.tensorName); + if (object.quantParameterTensorNames) { + if (!Array.isArray(object.quantParameterTensorNames)) + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: array expected"); + message.quantParameterTensorNames = []; + for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { + if (typeof object.quantParameterTensorNames[i] !== "object") + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: object expected"); + message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject(object.quantParameterTensorNames[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.TensorAnnotation} message TensorAnnotation + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorAnnotation.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.quantParameterTensorNames = []; + if (options.defaults) + object.tensorName = ""; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + object.tensorName = message.tensorName; + if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { + object.quantParameterTensorNames = []; + for (var j = 0; j < message.quantParameterTensorNames.length; ++j) + object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject(message.quantParameterTensorNames[j], options); + } + return object; + }; + + /** + * Converts this TensorAnnotation to JSON. + * @function toJSON + * @memberof onnx.TensorAnnotation + * @instance + * @returns {Object.} JSON object + */ + TensorAnnotation.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorAnnotation + * @function getTypeUrl + * @memberof onnx.TensorAnnotation + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorAnnotation"; + }; + + return TensorAnnotation; + })(); + + onnx.GraphProto = (function() { + + /** + * Properties of a GraphProto. + * @memberof onnx + * @interface IGraphProto + * @property {Array.|null} [node] GraphProto node + * @property {string|null} [name] GraphProto name + * @property {Array.|null} [initializer] GraphProto initializer + * @property {Array.|null} [sparseInitializer] GraphProto sparseInitializer + * @property {string|null} [docString] GraphProto docString + * @property {Array.|null} [input] GraphProto input + * @property {Array.|null} [output] GraphProto output + * @property {Array.|null} [valueInfo] GraphProto valueInfo + * @property {Array.|null} [quantizationAnnotation] GraphProto quantizationAnnotation + */ + + /** + * Constructs a new GraphProto. + * @memberof onnx + * @classdesc Represents a GraphProto. + * @implements IGraphProto + * @constructor + * @param {onnx.IGraphProto=} [properties] Properties to set + */ + function GraphProto(properties) { + this.node = []; + this.initializer = []; + this.sparseInitializer = []; + this.input = []; + this.output = []; + this.valueInfo = []; + this.quantizationAnnotation = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * GraphProto node. + * @member {Array.} node + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.node = $util.emptyArray; + + /** + * GraphProto name. + * @member {string} name + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.name = ""; + + /** + * GraphProto initializer. + * @member {Array.} initializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.initializer = $util.emptyArray; + + /** + * GraphProto sparseInitializer. + * @member {Array.} sparseInitializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.sparseInitializer = $util.emptyArray; + + /** + * GraphProto docString. + * @member {string} docString + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.docString = ""; + + /** + * GraphProto input. + * @member {Array.} input + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.input = $util.emptyArray; + + /** + * GraphProto output. + * @member {Array.} output + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.output = $util.emptyArray; + + /** + * GraphProto valueInfo. + * @member {Array.} valueInfo + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.valueInfo = $util.emptyArray; + + /** + * GraphProto quantizationAnnotation. + * @member {Array.} quantizationAnnotation + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.quantizationAnnotation = $util.emptyArray; + + /** + * Creates a new GraphProto instance using the specified properties. + * @function create + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto=} [properties] Properties to set + * @returns {onnx.GraphProto} GraphProto instance + */ + GraphProto.create = function create(properties) { + return new GraphProto(properties); + }; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encode + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.name); + if (message.initializer != null && message.initializer.length) + for (var i = 0; i < message.initializer.length; ++i) + $root.onnx.TensorProto.encode(message.initializer[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.docString); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + $root.onnx.ValueInfoProto.encode(message.input[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + $root.onnx.ValueInfoProto.encode(message.output[i], writer.uint32(/* id 12, wireType 2 =*/98).fork()).ldelim(); + if (message.valueInfo != null && message.valueInfo.length) + for (var i = 0; i < message.valueInfo.length; ++i) + $root.onnx.ValueInfoProto.encode(message.valueInfo[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) + for (var i = 0; i < message.quantizationAnnotation.length; ++i) + $root.onnx.TensorAnnotation.encode(message.quantizationAnnotation[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.sparseInitializer != null && message.sparseInitializer.length) + for (var i = 0; i < message.sparseInitializer.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseInitializer[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.GraphProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.name = reader.string(); + break; + } + case 5: { + if (!(message.initializer && message.initializer.length)) + message.initializer = []; + message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.sparseInitializer && message.sparseInitializer.length)) + message.sparseInitializer = []; + message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.docString = reader.string(); + break; + } + case 11: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 12: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 13: { + if (!(message.valueInfo && message.valueInfo.length)) + message.valueInfo = []; + message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 14: { + if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) + message.quantizationAnnotation = []; + message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a GraphProto message. + * @function verify + * @memberof onnx.GraphProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + GraphProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.initializer != null && message.hasOwnProperty("initializer")) { + if (!Array.isArray(message.initializer)) + return "initializer: array expected"; + for (var i = 0; i < message.initializer.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.initializer[i]); + if (error) + return "initializer." + error; + } + } + if (message.sparseInitializer != null && message.hasOwnProperty("sparseInitializer")) { + if (!Array.isArray(message.sparseInitializer)) + return "sparseInitializer: array expected"; + for (var i = 0; i < message.sparseInitializer.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); + if (error) + return "sparseInitializer." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.input[i]); + if (error) + return "input." + error; + } + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.output[i]); + if (error) + return "output." + error; + } + } + if (message.valueInfo != null && message.hasOwnProperty("valueInfo")) { + if (!Array.isArray(message.valueInfo)) + return "valueInfo: array expected"; + for (var i = 0; i < message.valueInfo.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); + if (error) + return "valueInfo." + error; + } + } + if (message.quantizationAnnotation != null && message.hasOwnProperty("quantizationAnnotation")) { + if (!Array.isArray(message.quantizationAnnotation)) + return "quantizationAnnotation: array expected"; + for (var i = 0; i < message.quantizationAnnotation.length; ++i) { + var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); + if (error) + return "quantizationAnnotation." + error; + } + } + return null; + }; + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.GraphProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.GraphProto} GraphProto + */ + GraphProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.GraphProto) + return object; + var message = new $root.onnx.GraphProto(); + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.GraphProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.GraphProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.name != null) + message.name = String(object.name); + if (object.initializer) { + if (!Array.isArray(object.initializer)) + throw TypeError(".onnx.GraphProto.initializer: array expected"); + message.initializer = []; + for (var i = 0; i < object.initializer.length; ++i) { + if (typeof object.initializer[i] !== "object") + throw TypeError(".onnx.GraphProto.initializer: object expected"); + message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); + } + } + if (object.sparseInitializer) { + if (!Array.isArray(object.sparseInitializer)) + throw TypeError(".onnx.GraphProto.sparseInitializer: array expected"); + message.sparseInitializer = []; + for (var i = 0; i < object.sparseInitializer.length; ++i) { + if (typeof object.sparseInitializer[i] !== "object") + throw TypeError(".onnx.GraphProto.sparseInitializer: object expected"); + message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.GraphProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) { + if (typeof object.input[i] !== "object") + throw TypeError(".onnx.GraphProto.input: object expected"); + message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); + } + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.GraphProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) { + if (typeof object.output[i] !== "object") + throw TypeError(".onnx.GraphProto.output: object expected"); + message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); + } + } + if (object.valueInfo) { + if (!Array.isArray(object.valueInfo)) + throw TypeError(".onnx.GraphProto.valueInfo: array expected"); + message.valueInfo = []; + for (var i = 0; i < object.valueInfo.length; ++i) { + if (typeof object.valueInfo[i] !== "object") + throw TypeError(".onnx.GraphProto.valueInfo: object expected"); + message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); + } + } + if (object.quantizationAnnotation) { + if (!Array.isArray(object.quantizationAnnotation)) + throw TypeError(".onnx.GraphProto.quantizationAnnotation: array expected"); + message.quantizationAnnotation = []; + for (var i = 0; i < object.quantizationAnnotation.length; ++i) { + if (typeof object.quantizationAnnotation[i] !== "object") + throw TypeError(".onnx.GraphProto.quantizationAnnotation: object expected"); + message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a GraphProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.GraphProto + * @static + * @param {onnx.GraphProto} message GraphProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + GraphProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.node = []; + object.initializer = []; + object.input = []; + object.output = []; + object.valueInfo = []; + object.quantizationAnnotation = []; + object.sparseInitializer = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.initializer && message.initializer.length) { + object.initializer = []; + for (var j = 0; j < message.initializer.length; ++j) + object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); + } + if (message.valueInfo && message.valueInfo.length) { + object.valueInfo = []; + for (var j = 0; j < message.valueInfo.length; ++j) + object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); + } + if (message.quantizationAnnotation && message.quantizationAnnotation.length) { + object.quantizationAnnotation = []; + for (var j = 0; j < message.quantizationAnnotation.length; ++j) + object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject(message.quantizationAnnotation[j], options); + } + if (message.sparseInitializer && message.sparseInitializer.length) { + object.sparseInitializer = []; + for (var j = 0; j < message.sparseInitializer.length; ++j) + object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); + } + return object; + }; + + /** + * Converts this GraphProto to JSON. + * @function toJSON + * @memberof onnx.GraphProto + * @instance + * @returns {Object.} JSON object + */ + GraphProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for GraphProto + * @function getTypeUrl + * @memberof onnx.GraphProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.GraphProto"; + }; + + return GraphProto; + })(); + + onnx.TensorProto = (function() { + + /** + * Properties of a TensorProto. + * @memberof onnx + * @interface ITensorProto + * @property {Array.|null} [dims] TensorProto dims + * @property {number|null} [dataType] TensorProto dataType + * @property {onnx.TensorProto.ISegment|null} [segment] TensorProto segment + * @property {Array.|null} [floatData] TensorProto floatData + * @property {Array.|null} [int32Data] TensorProto int32Data + * @property {Array.|null} [stringData] TensorProto stringData + * @property {Array.|null} [int64Data] TensorProto int64Data + * @property {string|null} [name] TensorProto name + * @property {string|null} [docString] TensorProto docString + * @property {Uint8Array|null} [rawData] TensorProto rawData + * @property {Array.|null} [externalData] TensorProto externalData + * @property {onnx.TensorProto.DataLocation|null} [dataLocation] TensorProto dataLocation + * @property {Array.|null} [doubleData] TensorProto doubleData + * @property {Array.|null} [uint64Data] TensorProto uint64Data + */ + + /** + * Constructs a new TensorProto. + * @memberof onnx + * @classdesc Represents a TensorProto. + * @implements ITensorProto + * @constructor + * @param {onnx.ITensorProto=} [properties] Properties to set + */ + function TensorProto(properties) { + this.dims = []; + this.floatData = []; + this.int32Data = []; + this.stringData = []; + this.int64Data = []; + this.externalData = []; + this.doubleData = []; + this.uint64Data = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorProto dims. + * @member {Array.} dims + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dims = $util.emptyArray; + + /** + * TensorProto dataType. + * @member {number} dataType + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataType = 0; + + /** + * TensorProto segment. + * @member {onnx.TensorProto.ISegment|null|undefined} segment + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.segment = null; + + /** + * TensorProto floatData. + * @member {Array.} floatData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.floatData = $util.emptyArray; + + /** + * TensorProto int32Data. + * @member {Array.} int32Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int32Data = $util.emptyArray; + + /** + * TensorProto stringData. + * @member {Array.} stringData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.stringData = $util.emptyArray; + + /** + * TensorProto int64Data. + * @member {Array.} int64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int64Data = $util.emptyArray; + + /** + * TensorProto name. + * @member {string} name + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.name = ""; + + /** + * TensorProto docString. + * @member {string} docString + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.docString = ""; + + /** + * TensorProto rawData. + * @member {Uint8Array} rawData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.rawData = $util.newBuffer([]); + + /** + * TensorProto externalData. + * @member {Array.} externalData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.externalData = $util.emptyArray; + + /** + * TensorProto dataLocation. + * @member {onnx.TensorProto.DataLocation} dataLocation + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataLocation = 0; + + /** + * TensorProto doubleData. + * @member {Array.} doubleData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.doubleData = $util.emptyArray; + + /** + * TensorProto uint64Data. + * @member {Array.} uint64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.uint64Data = $util.emptyArray; + + /** + * Creates a new TensorProto instance using the specified properties. + * @function create + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto=} [properties] Properties to set + * @returns {onnx.TensorProto} TensorProto instance + */ + TensorProto.create = function create(properties) { + return new TensorProto(properties); + }; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 1, wireType 2 =*/10).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + if (message.dataType != null && Object.hasOwnProperty.call(message, "dataType")) + writer.uint32(/* id 2, wireType 0 =*/16).int32(message.dataType); + if (message.segment != null && Object.hasOwnProperty.call(message, "segment")) + $root.onnx.TensorProto.Segment.encode(message.segment, writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.floatData != null && message.floatData.length) { + writer.uint32(/* id 4, wireType 2 =*/34).fork(); + for (var i = 0; i < message.floatData.length; ++i) + writer.float(message.floatData[i]); + writer.ldelim(); + } + if (message.int32Data != null && message.int32Data.length) { + writer.uint32(/* id 5, wireType 2 =*/42).fork(); + for (var i = 0; i < message.int32Data.length; ++i) + writer.int32(message.int32Data[i]); + writer.ldelim(); + } + if (message.stringData != null && message.stringData.length) + for (var i = 0; i < message.stringData.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).bytes(message.stringData[i]); + if (message.int64Data != null && message.int64Data.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.int64Data.length; ++i) + writer.int64(message.int64Data[i]); + writer.ldelim(); + } + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.name); + if (message.rawData != null && Object.hasOwnProperty.call(message, "rawData")) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.rawData); + if (message.doubleData != null && message.doubleData.length) { + writer.uint32(/* id 10, wireType 2 =*/82).fork(); + for (var i = 0; i < message.doubleData.length; ++i) + writer.double(message.doubleData[i]); + writer.ldelim(); + } + if (message.uint64Data != null && message.uint64Data.length) { + writer.uint32(/* id 11, wireType 2 =*/90).fork(); + for (var i = 0; i < message.uint64Data.length; ++i) + writer.uint64(message.uint64Data[i]); + writer.ldelim(); + } + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 12, wireType 2 =*/98).string(message.docString); + if (message.externalData != null && message.externalData.length) + for (var i = 0; i < message.externalData.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.externalData[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.dataLocation != null && Object.hasOwnProperty.call(message, "dataLocation")) + writer.uint32(/* id 14, wireType 0 =*/112).int32(message.dataLocation); + return writer; + }; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + case 2: { + message.dataType = reader.int32(); + break; + } + case 3: { + message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); + break; + } + case 4: { + if (!(message.floatData && message.floatData.length)) + message.floatData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floatData.push(reader.float()); + } else + message.floatData.push(reader.float()); + break; + } + case 5: { + if (!(message.int32Data && message.int32Data.length)) + message.int32Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int32Data.push(reader.int32()); + } else + message.int32Data.push(reader.int32()); + break; + } + case 6: { + if (!(message.stringData && message.stringData.length)) + message.stringData = []; + message.stringData.push(reader.bytes()); + break; + } + case 7: { + if (!(message.int64Data && message.int64Data.length)) + message.int64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int64Data.push(reader.int64()); + } else + message.int64Data.push(reader.int64()); + break; + } + case 8: { + message.name = reader.string(); + break; + } + case 12: { + message.docString = reader.string(); + break; + } + case 9: { + message.rawData = reader.bytes(); + break; + } + case 13: { + if (!(message.externalData && message.externalData.length)) + message.externalData = []; + message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 14: { + message.dataLocation = reader.int32(); + break; + } + case 10: { + if (!(message.doubleData && message.doubleData.length)) + message.doubleData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.doubleData.push(reader.double()); + } else + message.doubleData.push(reader.double()); + break; + } + case 11: { + if (!(message.uint64Data && message.uint64Data.length)) + message.uint64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.uint64Data.push(reader.uint64()); + } else + message.uint64Data.push(reader.uint64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorProto message. + * @function verify + * @memberof onnx.TensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + if (!$util.isInteger(message.dataType)) + return "dataType: integer expected"; + if (message.segment != null && message.hasOwnProperty("segment")) { + var error = $root.onnx.TensorProto.Segment.verify(message.segment); + if (error) + return "segment." + error; + } + if (message.floatData != null && message.hasOwnProperty("floatData")) { + if (!Array.isArray(message.floatData)) + return "floatData: array expected"; + for (var i = 0; i < message.floatData.length; ++i) + if (typeof message.floatData[i] !== "number") + return "floatData: number[] expected"; + } + if (message.int32Data != null && message.hasOwnProperty("int32Data")) { + if (!Array.isArray(message.int32Data)) + return "int32Data: array expected"; + for (var i = 0; i < message.int32Data.length; ++i) + if (!$util.isInteger(message.int32Data[i])) + return "int32Data: integer[] expected"; + } + if (message.stringData != null && message.hasOwnProperty("stringData")) { + if (!Array.isArray(message.stringData)) + return "stringData: array expected"; + for (var i = 0; i < message.stringData.length; ++i) + if (!(message.stringData[i] && typeof message.stringData[i].length === "number" || $util.isString(message.stringData[i]))) + return "stringData: buffer[] expected"; + } + if (message.int64Data != null && message.hasOwnProperty("int64Data")) { + if (!Array.isArray(message.int64Data)) + return "int64Data: array expected"; + for (var i = 0; i < message.int64Data.length; ++i) + if (!$util.isInteger(message.int64Data[i]) && !(message.int64Data[i] && $util.isInteger(message.int64Data[i].low) && $util.isInteger(message.int64Data[i].high))) + return "int64Data: integer|Long[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.rawData != null && message.hasOwnProperty("rawData")) + if (!(message.rawData && typeof message.rawData.length === "number" || $util.isString(message.rawData))) + return "rawData: buffer expected"; + if (message.externalData != null && message.hasOwnProperty("externalData")) { + if (!Array.isArray(message.externalData)) + return "externalData: array expected"; + for (var i = 0; i < message.externalData.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); + if (error) + return "externalData." + error; + } + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + switch (message.dataLocation) { + default: + return "dataLocation: enum value expected"; + case 0: + case 1: + break; + } + if (message.doubleData != null && message.hasOwnProperty("doubleData")) { + if (!Array.isArray(message.doubleData)) + return "doubleData: array expected"; + for (var i = 0; i < message.doubleData.length; ++i) + if (typeof message.doubleData[i] !== "number") + return "doubleData: number[] expected"; + } + if (message.uint64Data != null && message.hasOwnProperty("uint64Data")) { + if (!Array.isArray(message.uint64Data)) + return "uint64Data: array expected"; + for (var i = 0; i < message.uint64Data.length; ++i) + if (!$util.isInteger(message.uint64Data[i]) && !(message.uint64Data[i] && $util.isInteger(message.uint64Data[i].low) && $util.isInteger(message.uint64Data[i].high))) + return "uint64Data: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto} TensorProto + */ + TensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto) + return object; + var message = new $root.onnx.TensorProto(); + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.TensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + if (object.dataType != null) + message.dataType = object.dataType | 0; + if (object.segment != null) { + if (typeof object.segment !== "object") + throw TypeError(".onnx.TensorProto.segment: object expected"); + message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); + } + if (object.floatData) { + if (!Array.isArray(object.floatData)) + throw TypeError(".onnx.TensorProto.floatData: array expected"); + message.floatData = []; + for (var i = 0; i < object.floatData.length; ++i) + message.floatData[i] = Number(object.floatData[i]); + } + if (object.int32Data) { + if (!Array.isArray(object.int32Data)) + throw TypeError(".onnx.TensorProto.int32Data: array expected"); + message.int32Data = []; + for (var i = 0; i < object.int32Data.length; ++i) + message.int32Data[i] = object.int32Data[i] | 0; + } + if (object.stringData) { + if (!Array.isArray(object.stringData)) + throw TypeError(".onnx.TensorProto.stringData: array expected"); + message.stringData = []; + for (var i = 0; i < object.stringData.length; ++i) + if (typeof object.stringData[i] === "string") + $util.base64.decode(object.stringData[i], message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i])), 0); + else if (object.stringData[i].length >= 0) + message.stringData[i] = object.stringData[i]; + } + if (object.int64Data) { + if (!Array.isArray(object.int64Data)) + throw TypeError(".onnx.TensorProto.int64Data: array expected"); + message.int64Data = []; + for (var i = 0; i < object.int64Data.length; ++i) + if ($util.Long) + (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; + else if (typeof object.int64Data[i] === "string") + message.int64Data[i] = parseInt(object.int64Data[i], 10); + else if (typeof object.int64Data[i] === "number") + message.int64Data[i] = object.int64Data[i]; + else if (typeof object.int64Data[i] === "object") + message.int64Data[i] = new $util.LongBits(object.int64Data[i].low >>> 0, object.int64Data[i].high >>> 0).toNumber(); + } + if (object.name != null) + message.name = String(object.name); + if (object.docString != null) + message.docString = String(object.docString); + if (object.rawData != null) + if (typeof object.rawData === "string") + $util.base64.decode(object.rawData, message.rawData = $util.newBuffer($util.base64.length(object.rawData)), 0); + else if (object.rawData.length >= 0) + message.rawData = object.rawData; + if (object.externalData) { + if (!Array.isArray(object.externalData)) + throw TypeError(".onnx.TensorProto.externalData: array expected"); + message.externalData = []; + for (var i = 0; i < object.externalData.length; ++i) { + if (typeof object.externalData[i] !== "object") + throw TypeError(".onnx.TensorProto.externalData: object expected"); + message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); + } + } + switch (object.dataLocation) { + default: + if (typeof object.dataLocation === "number") { + message.dataLocation = object.dataLocation; + break; + } + break; + case "DEFAULT": + case 0: + message.dataLocation = 0; + break; + case "EXTERNAL": + case 1: + message.dataLocation = 1; + break; + } + if (object.doubleData) { + if (!Array.isArray(object.doubleData)) + throw TypeError(".onnx.TensorProto.doubleData: array expected"); + message.doubleData = []; + for (var i = 0; i < object.doubleData.length; ++i) + message.doubleData[i] = Number(object.doubleData[i]); + } + if (object.uint64Data) { + if (!Array.isArray(object.uint64Data)) + throw TypeError(".onnx.TensorProto.uint64Data: array expected"); + message.uint64Data = []; + for (var i = 0; i < object.uint64Data.length; ++i) + if ($util.Long) + (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; + else if (typeof object.uint64Data[i] === "string") + message.uint64Data[i] = parseInt(object.uint64Data[i], 10); + else if (typeof object.uint64Data[i] === "number") + message.uint64Data[i] = object.uint64Data[i]; + else if (typeof object.uint64Data[i] === "object") + message.uint64Data[i] = new $util.LongBits(object.uint64Data[i].low >>> 0, object.uint64Data[i].high >>> 0).toNumber(true); + } + return message; + }; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto + * @static + * @param {onnx.TensorProto} message TensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.dims = []; + object.floatData = []; + object.int32Data = []; + object.stringData = []; + object.int64Data = []; + object.doubleData = []; + object.uint64Data = []; + object.externalData = []; + } + if (options.defaults) { + object.dataType = 0; + object.segment = null; + object.name = ""; + if (options.bytes === String) + object.rawData = ""; + else { + object.rawData = []; + if (options.bytes !== Array) + object.rawData = $util.newBuffer(object.rawData); + } + object.docString = ""; + object.dataLocation = options.enums === String ? "DEFAULT" : 0; + } + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + object.dataType = message.dataType; + if (message.segment != null && message.hasOwnProperty("segment")) + object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); + if (message.floatData && message.floatData.length) { + object.floatData = []; + for (var j = 0; j < message.floatData.length; ++j) + object.floatData[j] = options.json && !isFinite(message.floatData[j]) ? String(message.floatData[j]) : message.floatData[j]; + } + if (message.int32Data && message.int32Data.length) { + object.int32Data = []; + for (var j = 0; j < message.int32Data.length; ++j) + object.int32Data[j] = message.int32Data[j]; + } + if (message.stringData && message.stringData.length) { + object.stringData = []; + for (var j = 0; j < message.stringData.length; ++j) + object.stringData[j] = options.bytes === String ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.stringData[j]) : message.stringData[j]; + } + if (message.int64Data && message.int64Data.length) { + object.int64Data = []; + for (var j = 0; j < message.int64Data.length; ++j) + if (typeof message.int64Data[j] === "number") + object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; + else + object.int64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.int64Data[j]) : options.longs === Number ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() : message.int64Data[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.rawData != null && message.hasOwnProperty("rawData")) + object.rawData = options.bytes === String ? $util.base64.encode(message.rawData, 0, message.rawData.length) : options.bytes === Array ? Array.prototype.slice.call(message.rawData) : message.rawData; + if (message.doubleData && message.doubleData.length) { + object.doubleData = []; + for (var j = 0; j < message.doubleData.length; ++j) + object.doubleData[j] = options.json && !isFinite(message.doubleData[j]) ? String(message.doubleData[j]) : message.doubleData[j]; + } + if (message.uint64Data && message.uint64Data.length) { + object.uint64Data = []; + for (var j = 0; j < message.uint64Data.length; ++j) + if (typeof message.uint64Data[j] === "number") + object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; + else + object.uint64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.uint64Data[j]) : options.longs === Number ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) : message.uint64Data[j]; + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.externalData && message.externalData.length) { + object.externalData = []; + for (var j = 0; j < message.externalData.length; ++j) + object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + object.dataLocation = options.enums === String ? $root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined ? message.dataLocation : $root.onnx.TensorProto.DataLocation[message.dataLocation] : message.dataLocation; + return object; + }; + + /** + * Converts this TensorProto to JSON. + * @function toJSON + * @memberof onnx.TensorProto + * @instance + * @returns {Object.} JSON object + */ + TensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorProto + * @function getTypeUrl + * @memberof onnx.TensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto"; + }; + + /** + * DataType enum. + * @name onnx.TensorProto.DataType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} UINT8=2 UINT8 value + * @property {number} INT8=3 INT8 value + * @property {number} UINT16=4 UINT16 value + * @property {number} INT16=5 INT16 value + * @property {number} INT32=6 INT32 value + * @property {number} INT64=7 INT64 value + * @property {number} STRING=8 STRING value + * @property {number} BOOL=9 BOOL value + * @property {number} FLOAT16=10 FLOAT16 value + * @property {number} DOUBLE=11 DOUBLE value + * @property {number} UINT32=12 UINT32 value + * @property {number} UINT64=13 UINT64 value + * @property {number} COMPLEX64=14 COMPLEX64 value + * @property {number} COMPLEX128=15 COMPLEX128 value + * @property {number} BFLOAT16=16 BFLOAT16 value + * @property {number} FLOAT8E4M3FN=17 FLOAT8E4M3FN value + * @property {number} FLOAT8E4M3FNUZ=18 FLOAT8E4M3FNUZ value + * @property {number} FLOAT8E5M2=19 FLOAT8E5M2 value + * @property {number} FLOAT8E5M2FNUZ=20 FLOAT8E5M2FNUZ value + */ + TensorProto.DataType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "UINT8"] = 2; + values[valuesById[3] = "INT8"] = 3; + values[valuesById[4] = "UINT16"] = 4; + values[valuesById[5] = "INT16"] = 5; + values[valuesById[6] = "INT32"] = 6; + values[valuesById[7] = "INT64"] = 7; + values[valuesById[8] = "STRING"] = 8; + values[valuesById[9] = "BOOL"] = 9; + values[valuesById[10] = "FLOAT16"] = 10; + values[valuesById[11] = "DOUBLE"] = 11; + values[valuesById[12] = "UINT32"] = 12; + values[valuesById[13] = "UINT64"] = 13; + values[valuesById[14] = "COMPLEX64"] = 14; + values[valuesById[15] = "COMPLEX128"] = 15; + values[valuesById[16] = "BFLOAT16"] = 16; + values[valuesById[17] = "FLOAT8E4M3FN"] = 17; + values[valuesById[18] = "FLOAT8E4M3FNUZ"] = 18; + values[valuesById[19] = "FLOAT8E5M2"] = 19; + values[valuesById[20] = "FLOAT8E5M2FNUZ"] = 20; + return values; + })(); + + TensorProto.Segment = (function() { + + /** + * Properties of a Segment. + * @memberof onnx.TensorProto + * @interface ISegment + * @property {number|Long|null} [begin] Segment begin + * @property {number|Long|null} [end] Segment end + */ + + /** + * Constructs a new Segment. + * @memberof onnx.TensorProto + * @classdesc Represents a Segment. + * @implements ISegment + * @constructor + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + */ + function Segment(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Segment begin. + * @member {number|Long} begin + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Segment end. + * @member {number|Long} end + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.end = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new Segment instance using the specified properties. + * @function create + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + * @returns {onnx.TensorProto.Segment} Segment instance + */ + Segment.create = function create(properties) { + return new Segment(properties); + }; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.begin != null && Object.hasOwnProperty.call(message, "begin")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.begin); + if (message.end != null && Object.hasOwnProperty.call(message, "end")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.end); + return writer; + }; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto.Segment(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.begin = reader.int64(); + break; + } + case 2: { + message.end = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Segment message. + * @function verify + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Segment.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.begin != null && message.hasOwnProperty("begin")) + if (!$util.isInteger(message.begin) && !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high))) + return "begin: integer|Long expected"; + if (message.end != null && message.hasOwnProperty("end")) + if (!$util.isInteger(message.end) && !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high))) + return "end: integer|Long expected"; + return null; + }; + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto.Segment} Segment + */ + Segment.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto.Segment) + return object; + var message = new $root.onnx.TensorProto.Segment(); + if (object.begin != null) + if ($util.Long) + (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; + else if (typeof object.begin === "string") + message.begin = parseInt(object.begin, 10); + else if (typeof object.begin === "number") + message.begin = object.begin; + else if (typeof object.begin === "object") + message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); + if (object.end != null) + if ($util.Long) + (message.end = $util.Long.fromValue(object.end)).unsigned = false; + else if (typeof object.end === "string") + message.end = parseInt(object.end, 10); + else if (typeof object.end === "number") + message.end = object.end; + else if (typeof object.end === "object") + message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.Segment} message Segment + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Segment.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.begin = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.begin = options.longs === String ? "0" : 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.end = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.end = options.longs === String ? "0" : 0; + } + if (message.begin != null && message.hasOwnProperty("begin")) + if (typeof message.begin === "number") + object.begin = options.longs === String ? String(message.begin) : message.begin; + else + object.begin = options.longs === String ? $util.Long.prototype.toString.call(message.begin) : options.longs === Number ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() : message.begin; + if (message.end != null && message.hasOwnProperty("end")) + if (typeof message.end === "number") + object.end = options.longs === String ? String(message.end) : message.end; + else + object.end = options.longs === String ? $util.Long.prototype.toString.call(message.end) : options.longs === Number ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() : message.end; + return object; + }; + + /** + * Converts this Segment to JSON. + * @function toJSON + * @memberof onnx.TensorProto.Segment + * @instance + * @returns {Object.} JSON object + */ + Segment.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Segment + * @function getTypeUrl + * @memberof onnx.TensorProto.Segment + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto.Segment"; + }; + + return Segment; + })(); + + /** + * DataLocation enum. + * @name onnx.TensorProto.DataLocation + * @enum {number} + * @property {number} DEFAULT=0 DEFAULT value + * @property {number} EXTERNAL=1 EXTERNAL value + */ + TensorProto.DataLocation = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "DEFAULT"] = 0; + values[valuesById[1] = "EXTERNAL"] = 1; + return values; + })(); + + return TensorProto; + })(); + + onnx.SparseTensorProto = (function() { + + /** + * Properties of a SparseTensorProto. + * @memberof onnx + * @interface ISparseTensorProto + * @property {onnx.ITensorProto|null} [values] SparseTensorProto values + * @property {onnx.ITensorProto|null} [indices] SparseTensorProto indices + * @property {Array.|null} [dims] SparseTensorProto dims + */ + + /** + * Constructs a new SparseTensorProto. + * @memberof onnx + * @classdesc Represents a SparseTensorProto. + * @implements ISparseTensorProto + * @constructor + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + */ + function SparseTensorProto(properties) { + this.dims = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensorProto values. + * @member {onnx.ITensorProto|null|undefined} values + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.values = null; + + /** + * SparseTensorProto indices. + * @member {onnx.ITensorProto|null|undefined} indices + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.indices = null; + + /** + * SparseTensorProto dims. + * @member {Array.} dims + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.dims = $util.emptyArray; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @function create + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + * @returns {onnx.SparseTensorProto} SparseTensorProto instance + */ + SparseTensorProto.create = function create(properties) { + return new SparseTensorProto(properties); + }; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.values != null && Object.hasOwnProperty.call(message, "values")) + $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.indices != null && Object.hasOwnProperty.call(message, "indices")) + $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 3, wireType 2 =*/26).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + return writer; + }; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.SparseTensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensorProto message. + * @function verify + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.values != null && message.hasOwnProperty("values")) { + var error = $root.onnx.TensorProto.verify(message.values); + if (error) + return "values." + error; + } + if (message.indices != null && message.hasOwnProperty("indices")) { + var error = $root.onnx.TensorProto.verify(message.indices); + if (error) + return "indices." + error; + } + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.SparseTensorProto} SparseTensorProto + */ + SparseTensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.SparseTensorProto) + return object; + var message = new $root.onnx.SparseTensorProto(); + if (object.values != null) { + if (typeof object.values !== "object") + throw TypeError(".onnx.SparseTensorProto.values: object expected"); + message.values = $root.onnx.TensorProto.fromObject(object.values); + } + if (object.indices != null) { + if (typeof object.indices !== "object") + throw TypeError(".onnx.SparseTensorProto.indices: object expected"); + message.indices = $root.onnx.TensorProto.fromObject(object.indices); + } + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.SparseTensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.SparseTensorProto} message SparseTensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dims = []; + if (options.defaults) { + object.values = null; + object.indices = null; + } + if (message.values != null && message.hasOwnProperty("values")) + object.values = $root.onnx.TensorProto.toObject(message.values, options); + if (message.indices != null && message.hasOwnProperty("indices")) + object.indices = $root.onnx.TensorProto.toObject(message.indices, options); + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + return object; + }; + + /** + * Converts this SparseTensorProto to JSON. + * @function toJSON + * @memberof onnx.SparseTensorProto + * @instance + * @returns {Object.} JSON object + */ + SparseTensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensorProto + * @function getTypeUrl + * @memberof onnx.SparseTensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.SparseTensorProto"; + }; + + return SparseTensorProto; + })(); + + onnx.TensorShapeProto = (function() { + + /** + * Properties of a TensorShapeProto. + * @memberof onnx + * @interface ITensorShapeProto + * @property {Array.|null} [dim] TensorShapeProto dim + */ + + /** + * Constructs a new TensorShapeProto. + * @memberof onnx + * @classdesc Represents a TensorShapeProto. + * @implements ITensorShapeProto + * @constructor + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + */ + function TensorShapeProto(properties) { + this.dim = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorShapeProto dim. + * @member {Array.} dim + * @memberof onnx.TensorShapeProto + * @instance + */ + TensorShapeProto.prototype.dim = $util.emptyArray; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + * @returns {onnx.TensorShapeProto} TensorShapeProto instance + */ + TensorShapeProto.create = function create(properties) { + return new TensorShapeProto(properties); + }; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dim != null && message.dim.length) + for (var i = 0; i < message.dim.length; ++i) + $root.onnx.TensorShapeProto.Dimension.encode(message.dim[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dim && message.dim.length)) + message.dim = []; + message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorShapeProto message. + * @function verify + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorShapeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dim != null && message.hasOwnProperty("dim")) { + if (!Array.isArray(message.dim)) + return "dim: array expected"; + for (var i = 0; i < message.dim.length; ++i) { + var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); + if (error) + return "dim." + error; + } + } + return null; + }; + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto} TensorShapeProto + */ + TensorShapeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto) + return object; + var message = new $root.onnx.TensorShapeProto(); + if (object.dim) { + if (!Array.isArray(object.dim)) + throw TypeError(".onnx.TensorShapeProto.dim: array expected"); + message.dim = []; + for (var i = 0; i < object.dim.length; ++i) { + if (typeof object.dim[i] !== "object") + throw TypeError(".onnx.TensorShapeProto.dim: object expected"); + message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.TensorShapeProto} message TensorShapeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorShapeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dim = []; + if (message.dim && message.dim.length) { + object.dim = []; + for (var j = 0; j < message.dim.length; ++j) + object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); + } + return object; + }; + + /** + * Converts this TensorShapeProto to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto + * @instance + * @returns {Object.} JSON object + */ + TensorShapeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorShapeProto + * @function getTypeUrl + * @memberof onnx.TensorShapeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto"; + }; + + TensorShapeProto.Dimension = (function() { + + /** + * Properties of a Dimension. + * @memberof onnx.TensorShapeProto + * @interface IDimension + * @property {number|Long|null} [dimValue] Dimension dimValue + * @property {string|null} [dimParam] Dimension dimParam + * @property {string|null} [denotation] Dimension denotation + */ + + /** + * Constructs a new Dimension. + * @memberof onnx.TensorShapeProto + * @classdesc Represents a Dimension. + * @implements IDimension + * @constructor + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + */ + function Dimension(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Dimension dimValue. + * @member {number|Long|null|undefined} dimValue + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimValue = null; + + /** + * Dimension dimParam. + * @member {string|null|undefined} dimParam + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimParam = null; + + /** + * Dimension denotation. + * @member {string} denotation + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * Dimension value. + * @member {"dimValue"|"dimParam"|undefined} value + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Object.defineProperty(Dimension.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["dimValue", "dimParam"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new Dimension instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + * @returns {onnx.TensorShapeProto.Dimension} Dimension instance + */ + Dimension.create = function create(properties) { + return new Dimension(properties); + }; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dimValue != null && Object.hasOwnProperty.call(message, "dimValue")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.dimValue); + if (message.dimParam != null && Object.hasOwnProperty.call(message, "dimParam")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.dimParam); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.denotation); + return writer; + }; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto.Dimension(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.dimValue = reader.int64(); + break; + } + case 2: { + message.dimParam = reader.string(); + break; + } + case 3: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Dimension message. + * @function verify + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Dimension.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + properties.value = 1; + if (!$util.isInteger(message.dimValue) && !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high))) + return "dimValue: integer|Long expected"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + if (!$util.isString(message.dimParam)) + return "dimParam: string expected"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto.Dimension} Dimension + */ + Dimension.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto.Dimension) + return object; + var message = new $root.onnx.TensorShapeProto.Dimension(); + if (object.dimValue != null) + if ($util.Long) + (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; + else if (typeof object.dimValue === "string") + message.dimValue = parseInt(object.dimValue, 10); + else if (typeof object.dimValue === "number") + message.dimValue = object.dimValue; + else if (typeof object.dimValue === "object") + message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); + if (object.dimParam != null) + message.dimParam = String(object.dimParam); + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.Dimension} message Dimension + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Dimension.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + if (typeof message.dimValue === "number") + object.dimValue = options.longs === String ? String(message.dimValue) : message.dimValue; + else + object.dimValue = options.longs === String ? $util.Long.prototype.toString.call(message.dimValue) : options.longs === Number ? new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() : message.dimValue; + if (options.oneofs) + object.value = "dimValue"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + object.dimParam = message.dimParam; + if (options.oneofs) + object.value = "dimParam"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + return object; + }; + + /** + * Converts this Dimension to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto.Dimension + * @instance + * @returns {Object.} JSON object + */ + Dimension.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Dimension + * @function getTypeUrl + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto.Dimension"; + }; + + return Dimension; + })(); + + return TensorShapeProto; + })(); + + onnx.TypeProto = (function() { + + /** + * Properties of a TypeProto. + * @memberof onnx + * @interface ITypeProto + * @property {onnx.TypeProto.ITensor|null} [tensorType] TypeProto tensorType + * @property {onnx.TypeProto.ISequence|null} [sequenceType] TypeProto sequenceType + * @property {onnx.TypeProto.IMap|null} [mapType] TypeProto mapType + * @property {onnx.TypeProto.IOptional|null} [optionalType] TypeProto optionalType + * @property {onnx.TypeProto.ISparseTensor|null} [sparseTensorType] TypeProto sparseTensorType + * @property {string|null} [denotation] TypeProto denotation + */ + + /** + * Constructs a new TypeProto. + * @memberof onnx + * @classdesc Represents a TypeProto. + * @implements ITypeProto + * @constructor + * @param {onnx.ITypeProto=} [properties] Properties to set + */ + function TypeProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TypeProto tensorType. + * @member {onnx.TypeProto.ITensor|null|undefined} tensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.tensorType = null; + + /** + * TypeProto sequenceType. + * @member {onnx.TypeProto.ISequence|null|undefined} sequenceType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sequenceType = null; + + /** + * TypeProto mapType. + * @member {onnx.TypeProto.IMap|null|undefined} mapType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.mapType = null; + + /** + * TypeProto optionalType. + * @member {onnx.TypeProto.IOptional|null|undefined} optionalType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.optionalType = null; + + /** + * TypeProto sparseTensorType. + * @member {onnx.TypeProto.ISparseTensor|null|undefined} sparseTensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sparseTensorType = null; + + /** + * TypeProto denotation. + * @member {string} denotation + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * TypeProto value. + * @member {"tensorType"|"sequenceType"|"mapType"|"optionalType"|"sparseTensorType"|undefined} value + * @memberof onnx.TypeProto + * @instance + */ + Object.defineProperty(TypeProto.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["tensorType", "sequenceType", "mapType", "optionalType", "sparseTensorType"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new TypeProto instance using the specified properties. + * @function create + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto=} [properties] Properties to set + * @returns {onnx.TypeProto} TypeProto instance + */ + TypeProto.create = function create(properties) { + return new TypeProto(properties); + }; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorType != null && Object.hasOwnProperty.call(message, "tensorType")) + $root.onnx.TypeProto.Tensor.encode(message.tensorType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.sequenceType != null && Object.hasOwnProperty.call(message, "sequenceType")) + $root.onnx.TypeProto.Sequence.encode(message.sequenceType, writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + if (message.mapType != null && Object.hasOwnProperty.call(message, "mapType")) + $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.denotation); + if (message.sparseTensorType != null && Object.hasOwnProperty.call(message, "sparseTensorType")) + $root.onnx.TypeProto.SparseTensor.encode(message.sparseTensorType, writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.optionalType != null && Object.hasOwnProperty.call(message, "optionalType")) + $root.onnx.TypeProto.Optional.encode(message.optionalType, writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); + break; + } + case 4: { + message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); + break; + } + case 5: { + message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); + break; + } + case 9: { + message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); + break; + } + case 8: { + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); + break; + } + case 6: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TypeProto message. + * @function verify + * @memberof onnx.TypeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TypeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + properties.value = 1; + { + var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); + if (error) + return "tensorType." + error; + } + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); + if (error) + return "sequenceType." + error; + } + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Map.verify(message.mapType); + if (error) + return "mapType." + error; + } + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); + if (error) + return "optionalType." + error; + } + } + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType); + if (error) + return "sparseTensorType." + error; + } + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto} TypeProto + */ + TypeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto) + return object; + var message = new $root.onnx.TypeProto(); + if (object.tensorType != null) { + if (typeof object.tensorType !== "object") + throw TypeError(".onnx.TypeProto.tensorType: object expected"); + message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType); + } + if (object.sequenceType != null) { + if (typeof object.sequenceType !== "object") + throw TypeError(".onnx.TypeProto.sequenceType: object expected"); + message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType); + } + if (object.mapType != null) { + if (typeof object.mapType !== "object") + throw TypeError(".onnx.TypeProto.mapType: object expected"); + message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType); + } + if (object.optionalType != null) { + if (typeof object.optionalType !== "object") + throw TypeError(".onnx.TypeProto.optionalType: object expected"); + message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType); + } + if (object.sparseTensorType != null) { + if (typeof object.sparseTensorType !== "object") + throw TypeError(".onnx.TypeProto.sparseTensorType: object expected"); + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType); + } + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto + * @static + * @param {onnx.TypeProto} message TypeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TypeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); + if (options.oneofs) + object.value = "tensorType"; + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); + if (options.oneofs) + object.value = "sequenceType"; + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); + if (options.oneofs) + object.value = "mapType"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); + if (options.oneofs) + object.value = "sparseTensorType"; + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); + if (options.oneofs) + object.value = "optionalType"; + } + return object; + }; + + /** + * Converts this TypeProto to JSON. + * @function toJSON + * @memberof onnx.TypeProto + * @instance + * @returns {Object.} JSON object + */ + TypeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TypeProto + * @function getTypeUrl + * @memberof onnx.TypeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto"; + }; + + TypeProto.Tensor = (function() { + + /** + * Properties of a Tensor. + * @memberof onnx.TypeProto + * @interface ITensor + * @property {number|null} [elemType] Tensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] Tensor shape + */ + + /** + * Constructs a new Tensor. + * @memberof onnx.TypeProto + * @classdesc Represents a Tensor. + * @implements ITensor + * @constructor + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + */ + function Tensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Tensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.elemType = 0; + + /** + * Tensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.shape = null; + + /** + * Creates a new Tensor instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + * @returns {onnx.TypeProto.Tensor} Tensor instance + */ + Tensor.create = function create(properties) { + return new Tensor(properties); + }; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Tensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Tensor message. + * @function verify + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Tensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Tensor} Tensor + */ + Tensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Tensor) + return object; + var message = new $root.onnx.TypeProto.Tensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.Tensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.Tensor} message Tensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Tensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this Tensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Tensor + * @instance + * @returns {Object.} JSON object + */ + Tensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Tensor + * @function getTypeUrl + * @memberof onnx.TypeProto.Tensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Tensor"; + }; + + return Tensor; + })(); + + TypeProto.Sequence = (function() { + + /** + * Properties of a Sequence. + * @memberof onnx.TypeProto + * @interface ISequence + * @property {onnx.ITypeProto|null} [elemType] Sequence elemType + */ + + /** + * Constructs a new Sequence. + * @memberof onnx.TypeProto + * @classdesc Represents a Sequence. + * @implements ISequence + * @constructor + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + */ + function Sequence(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Sequence elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Sequence + * @instance + */ + Sequence.prototype.elemType = null; + + /** + * Creates a new Sequence instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + * @returns {onnx.TypeProto.Sequence} Sequence instance + */ + Sequence.create = function create(properties) { + return new Sequence(properties); + }; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Sequence(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Sequence message. + * @function verify + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Sequence.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Sequence} Sequence + */ + Sequence.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Sequence) + return object; + var message = new $root.onnx.TypeProto.Sequence(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Sequence.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.Sequence} message Sequence + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Sequence.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Sequence to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Sequence + * @instance + * @returns {Object.} JSON object + */ + Sequence.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Sequence + * @function getTypeUrl + * @memberof onnx.TypeProto.Sequence + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Sequence"; + }; + + return Sequence; + })(); + + TypeProto.Map = (function() { + + /** + * Properties of a Map. + * @memberof onnx.TypeProto + * @interface IMap + * @property {number|null} [keyType] Map keyType + * @property {onnx.ITypeProto|null} [valueType] Map valueType + */ + + /** + * Constructs a new Map. + * @memberof onnx.TypeProto + * @classdesc Represents a Map. + * @implements IMap + * @constructor + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + */ + function Map(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Map keyType. + * @member {number} keyType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.keyType = 0; + + /** + * Map valueType. + * @member {onnx.ITypeProto|null|undefined} valueType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.valueType = null; + + /** + * Creates a new Map instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + * @returns {onnx.TypeProto.Map} Map instance + */ + Map.create = function create(properties) { + return new Map(properties); + }; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.keyType != null && Object.hasOwnProperty.call(message, "keyType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.keyType); + if (message.valueType != null && Object.hasOwnProperty.call(message, "valueType")) + $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Map message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Map(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.keyType = reader.int32(); + break; + } + case 2: { + message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Map message. + * @function verify + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Map.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.keyType != null && message.hasOwnProperty("keyType")) + if (!$util.isInteger(message.keyType)) + return "keyType: integer expected"; + if (message.valueType != null && message.hasOwnProperty("valueType")) { + var error = $root.onnx.TypeProto.verify(message.valueType); + if (error) + return "valueType." + error; + } + return null; + }; + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Map} Map + */ + Map.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Map) + return object; + var message = new $root.onnx.TypeProto.Map(); + if (object.keyType != null) + message.keyType = object.keyType | 0; + if (object.valueType != null) { + if (typeof object.valueType !== "object") + throw TypeError(".onnx.TypeProto.Map.valueType: object expected"); + message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); + } + return message; + }; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.Map} message Map + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Map.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.keyType = 0; + object.valueType = null; + } + if (message.keyType != null && message.hasOwnProperty("keyType")) + object.keyType = message.keyType; + if (message.valueType != null && message.hasOwnProperty("valueType")) + object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); + return object; + }; + + /** + * Converts this Map to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Map + * @instance + * @returns {Object.} JSON object + */ + Map.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Map + * @function getTypeUrl + * @memberof onnx.TypeProto.Map + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Map"; + }; + + return Map; + })(); + + TypeProto.Optional = (function() { + + /** + * Properties of an Optional. + * @memberof onnx.TypeProto + * @interface IOptional + * @property {onnx.ITypeProto|null} [elemType] Optional elemType + */ + + /** + * Constructs a new Optional. + * @memberof onnx.TypeProto + * @classdesc Represents an Optional. + * @implements IOptional + * @constructor + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + */ + function Optional(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Optional elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Optional + * @instance + */ + Optional.prototype.elemType = null; + + /** + * Creates a new Optional instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + * @returns {onnx.TypeProto.Optional} Optional instance + */ + Optional.create = function create(properties) { + return new Optional(properties); + }; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Optional(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an Optional message. + * @function verify + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Optional.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Optional} Optional + */ + Optional.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Optional) + return object; + var message = new $root.onnx.TypeProto.Optional(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Optional.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.Optional} message Optional + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Optional.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Optional to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Optional + * @instance + * @returns {Object.} JSON object + */ + Optional.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Optional + * @function getTypeUrl + * @memberof onnx.TypeProto.Optional + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Optional"; + }; + + return Optional; + })(); + + TypeProto.SparseTensor = (function() { + + /** + * Properties of a SparseTensor. + * @memberof onnx.TypeProto + * @interface ISparseTensor + * @property {number|null} [elemType] SparseTensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] SparseTensor shape + */ + + /** + * Constructs a new SparseTensor. + * @memberof onnx.TypeProto + * @classdesc Represents a SparseTensor. + * @implements ISparseTensor + * @constructor + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + */ + function SparseTensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.elemType = 0; + + /** + * SparseTensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.shape = null; + + /** + * Creates a new SparseTensor instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + * @returns {onnx.TypeProto.SparseTensor} SparseTensor instance + */ + SparseTensor.create = function create(properties) { + return new SparseTensor(properties); + }; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.SparseTensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensor message. + * @function verify + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + */ + SparseTensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.SparseTensor) + return object; + var message = new $root.onnx.TypeProto.SparseTensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.SparseTensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.SparseTensor} message SparseTensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this SparseTensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.SparseTensor + * @instance + * @returns {Object.} JSON object + */ + SparseTensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensor + * @function getTypeUrl + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.SparseTensor"; + }; + + return SparseTensor; + })(); + + return TypeProto; + })(); + + onnx.OperatorSetIdProto = (function() { + + /** + * Properties of an OperatorSetIdProto. + * @memberof onnx + * @interface IOperatorSetIdProto + * @property {string|null} [domain] OperatorSetIdProto domain + * @property {number|Long|null} [version] OperatorSetIdProto version + */ + + /** + * Constructs a new OperatorSetIdProto. + * @memberof onnx + * @classdesc Represents an OperatorSetIdProto. + * @implements IOperatorSetIdProto + * @constructor + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + */ + function OperatorSetIdProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * OperatorSetIdProto domain. + * @member {string} domain + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.domain = ""; + + /** + * OperatorSetIdProto version. + * @member {number|Long} version + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @function create + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto instance + */ + OperatorSetIdProto.create = function create(properties) { + return new OperatorSetIdProto(properties); + }; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.domain); + if (message.version != null && Object.hasOwnProperty.call(message, "version")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.version); + return writer; + }; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.OperatorSetIdProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.domain = reader.string(); + break; + } + case 2: { + message.version = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an OperatorSetIdProto message. + * @function verify + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + OperatorSetIdProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.version != null && message.hasOwnProperty("version")) + if (!$util.isInteger(message.version) && !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high))) + return "version: integer|Long expected"; + return null; + }; + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + */ + OperatorSetIdProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.OperatorSetIdProto) + return object; + var message = new $root.onnx.OperatorSetIdProto(); + if (object.domain != null) + message.domain = String(object.domain); + if (object.version != null) + if ($util.Long) + (message.version = $util.Long.fromValue(object.version)).unsigned = false; + else if (typeof object.version === "string") + message.version = parseInt(object.version, 10); + else if (typeof object.version === "number") + message.version = object.version; + else if (typeof object.version === "object") + message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.OperatorSetIdProto} message OperatorSetIdProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + OperatorSetIdProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.version = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.version = options.longs === String ? "0" : 0; + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.version != null && message.hasOwnProperty("version")) + if (typeof message.version === "number") + object.version = options.longs === String ? String(message.version) : message.version; + else + object.version = options.longs === String ? $util.Long.prototype.toString.call(message.version) : options.longs === Number ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() : message.version; + return object; + }; + + /** + * Converts this OperatorSetIdProto to JSON. + * @function toJSON + * @memberof onnx.OperatorSetIdProto + * @instance + * @returns {Object.} JSON object + */ + OperatorSetIdProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for OperatorSetIdProto + * @function getTypeUrl + * @memberof onnx.OperatorSetIdProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.OperatorSetIdProto"; + }; + + return OperatorSetIdProto; + })(); + + /** + * OperatorStatus enum. + * @name onnx.OperatorStatus + * @enum {number} + * @property {number} EXPERIMENTAL=0 EXPERIMENTAL value + * @property {number} STABLE=1 STABLE value + */ + onnx.OperatorStatus = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "EXPERIMENTAL"] = 0; + values[valuesById[1] = "STABLE"] = 1; + return values; + })(); + + onnx.FunctionProto = (function() { + + /** + * Properties of a FunctionProto. + * @memberof onnx + * @interface IFunctionProto + * @property {string|null} [name] FunctionProto name + * @property {Array.|null} [input] FunctionProto input + * @property {Array.|null} [output] FunctionProto output + * @property {Array.|null} [attribute] FunctionProto attribute + * @property {Array.|null} [attributeProto] FunctionProto attributeProto + * @property {Array.|null} [node] FunctionProto node + * @property {string|null} [docString] FunctionProto docString + * @property {Array.|null} [opsetImport] FunctionProto opsetImport + * @property {string|null} [domain] FunctionProto domain + */ + + /** + * Constructs a new FunctionProto. + * @memberof onnx + * @classdesc Represents a FunctionProto. + * @implements IFunctionProto + * @constructor + * @param {onnx.IFunctionProto=} [properties] Properties to set + */ + function FunctionProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + this.attributeProto = []; + this.node = []; + this.opsetImport = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * FunctionProto name. + * @member {string} name + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.name = ""; + + /** + * FunctionProto input. + * @member {Array.} input + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.input = $util.emptyArray; + + /** + * FunctionProto output. + * @member {Array.} output + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.output = $util.emptyArray; + + /** + * FunctionProto attribute. + * @member {Array.} attribute + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attribute = $util.emptyArray; + + /** + * FunctionProto attributeProto. + * @member {Array.} attributeProto + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attributeProto = $util.emptyArray; + + /** + * FunctionProto node. + * @member {Array.} node + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.node = $util.emptyArray; + + /** + * FunctionProto docString. + * @member {string} docString + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.docString = ""; + + /** + * FunctionProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.opsetImport = $util.emptyArray; + + /** + * FunctionProto domain. + * @member {string} domain + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.domain = ""; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @function create + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto=} [properties] Properties to set + * @returns {onnx.FunctionProto} FunctionProto instance + */ + FunctionProto.create = function create(properties) { + return new FunctionProto(properties); + }; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encode + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 5, wireType 2 =*/42).string(message.output[i]); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.attribute[i]); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.docString); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.domain); + if (message.attributeProto != null && message.attributeProto.length) + for (var i = 0; i < message.attributeProto.length; ++i) + $root.onnx.AttributeProto.encode(message.attributeProto[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.FunctionProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 4: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 5: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 6: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push(reader.string()); + break; + } + case 11: { + if (!(message.attributeProto && message.attributeProto.length)) + message.attributeProto = []; + message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 7: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 8: { + message.docString = reader.string(); + break; + } + case 9: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.domain = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a FunctionProto message. + * @function verify + * @memberof onnx.FunctionProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + FunctionProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) + if (!$util.isString(message.attribute[i])) + return "attribute: string[] expected"; + } + if (message.attributeProto != null && message.hasOwnProperty("attributeProto")) { + if (!Array.isArray(message.attributeProto)) + return "attributeProto: array expected"; + for (var i = 0; i < message.attributeProto.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); + if (error) + return "attributeProto." + error; + } + } + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + return null; + }; + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.FunctionProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.FunctionProto} FunctionProto + */ + FunctionProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.FunctionProto) + return object; + var message = new $root.onnx.FunctionProto(); + if (object.name != null) + message.name = String(object.name); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.FunctionProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.FunctionProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.FunctionProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) + message.attribute[i] = String(object.attribute[i]); + } + if (object.attributeProto) { + if (!Array.isArray(object.attributeProto)) + throw TypeError(".onnx.FunctionProto.attributeProto: array expected"); + message.attributeProto = []; + for (var i = 0; i < object.attributeProto.length; ++i) { + if (typeof object.attributeProto[i] !== "object") + throw TypeError(".onnx.FunctionProto.attributeProto: object expected"); + message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); + } + } + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.FunctionProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.FunctionProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.FunctionProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.FunctionProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.domain != null) + message.domain = String(object.domain); + return message; + }; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.FunctionProto + * @static + * @param {onnx.FunctionProto} message FunctionProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + FunctionProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + object.node = []; + object.opsetImport = []; + object.attributeProto = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + object.domain = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = message.attribute[j]; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.attributeProto && message.attributeProto.length) { + object.attributeProto = []; + for (var j = 0; j < message.attributeProto.length; ++j) + object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); + } + return object; + }; + + /** + * Converts this FunctionProto to JSON. + * @function toJSON + * @memberof onnx.FunctionProto + * @instance + * @returns {Object.} JSON object + */ + FunctionProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for FunctionProto + * @function getTypeUrl + * @memberof onnx.FunctionProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.FunctionProto"; + }; + + return FunctionProto; + })(); + + return onnx; +})(); + +module.exports = $root; diff --git a/js/node/test/test-utils.ts b/js/node/test/test-utils.ts index 968e8a1881810..3eef90356a335 100644 --- a/js/node/test/test-utils.ts +++ b/js/node/test/test-utils.ts @@ -4,10 +4,11 @@ import assert from 'assert'; import * as fs from 'fs-extra'; import {jsonc} from 'jsonc'; -import * as onnx_proto from 'onnx-proto'; import {InferenceSession, Tensor} from 'onnxruntime-common'; import * as path from 'path'; +import * as onnx_proto from './ort-schema/protobuf/onnx'; + export const TEST_ROOT = __dirname; export const TEST_DATA_ROOT = path.join(TEST_ROOT, 'testdata'); diff --git a/js/node/tsconfig.json b/js/node/tsconfig.json index 417be5c373775..c154c3e148ed0 100644 --- a/js/node/tsconfig.json +++ b/js/node/tsconfig.json @@ -1,7 +1,6 @@ { "extends": "../tsconfig.json", "compilerOptions": { - "module": "Node16", "outDir": "dist" }, "include": ["lib"] diff --git a/js/package-lock.json b/js/package-lock.json index c87a58a3196d6..c16a8b59a3a6f 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -3391,9 +3391,9 @@ } }, "node_modules/normalize-package-data/node_modules/semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true, "bin": { "semver": "bin/semver" @@ -7011,9 +7011,9 @@ }, "dependencies": { "semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true } } diff --git a/js/tsconfig.json b/js/tsconfig.json index 41c94b8afe86b..faf9066b3d96e 100644 --- a/js/tsconfig.json +++ b/js/tsconfig.json @@ -1,6 +1,6 @@ { "compilerOptions": { - "module": "ES2015", + "module": "Node16", "moduleResolution": "Node16", "esModuleInterop": true, "target": "ES2020", @@ -10,7 +10,7 @@ "noImplicitAny": true, "noImplicitReturns": true, "noImplicitThis": true, - "noUnusedParameters": false, + "noUnusedParameters": true, "alwaysStrict": true, "strictNullChecks": true, "pretty": true, diff --git a/js/tsconfig.tools.json b/js/tsconfig.tools.json index 366e0d85b5b80..a70ca0388034d 100644 --- a/js/tsconfig.tools.json +++ b/js/tsconfig.tools.json @@ -1,7 +1,6 @@ { "extends": "./tsconfig.json", "compilerOptions": { - "module": "Node16", "declaration": false, "sourceMap": false } diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 0b82a9c031baa..b246e19137888 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -20,6 +20,7 @@ Do not modify directly.* | Asinh | ai.onnx(9+) | | | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | +| Attention | com.microsoft(1+) | need implementing mask and past/present | | AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | | BiasAdd | com.microsoft(1+) | | | BiasSplitGelu | com.microsoft(1+) | | @@ -61,6 +62,7 @@ Do not modify directly.* | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | | Mul | ai.onnx(7-12,13,14+) | | +| MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present | | Neg | ai.onnx(6-12,13+) | | | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 8b14b57acc062..fb714bf5996f1 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -/* eslint-disable @typescript-eslint/no-unused-vars, @typescript-eslint/naming-convention */ +/* eslint-disable @typescript-eslint/naming-convention */ /** * The interface BuildDefinitions contains a set of flags which are defined at build time. diff --git a/js/web/lib/onnxjs/attribute-with-cache-key.ts b/js/web/lib/onnxjs/attribute-with-cache-key.ts index 6608b00471e77..5d47570f267a6 100644 --- a/js/web/lib/onnxjs/attribute-with-cache-key.ts +++ b/js/web/lib/onnxjs/attribute-with-cache-key.ts @@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl { Object.assign(this, attribute); } - private _cacheKey: string; + private key: string; public get cacheKey(): string { - if (!this._cacheKey) { - this._cacheKey = + if (!this.key) { + this.key = Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); } - return this._cacheKey; + return this.key; } } diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts b/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts index dd3f1b30dfb46..1f2b27c7bdea8 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts @@ -186,7 +186,7 @@ export class CoordsGlslLib extends GlslLib { /** * 1D packed output coordinates. */ - protected getOutputPacked1DCoords(shape: [number], texShape: [number, number]): GlslLibRoutine { + protected getOutputPacked1DCoords(_shape: [number], texShape: [number, number]): GlslLibRoutine { const packedTexShape = texShape; let source = ''; if (packedTexShape[0] === 1) { @@ -331,7 +331,7 @@ export class CoordsGlslLib extends GlslLib { /** * Unpacked 1D output coordinates. */ - protected getOutputUnpacked1DCoords(shape: [number], texShape: [number, number]): GlslLibRoutine { + protected getOutputUnpacked1DCoords(_shape: [number], texShape: [number, number]): GlslLibRoutine { const source = ` int getOutputCoords() { ivec2 resTexRC = ivec2(TexCoords.xy * @@ -641,7 +641,7 @@ export class CoordsGlslLib extends GlslLib { if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { - unpackedCoordsSnippet = inShape.map((s, i) => `coords.${fields[i + rankDiff]}`).join(', '); + unpackedCoordsSnippet = inShape.map((_s, i) => `coords.${fields[i + rankDiff]}`).join(', '); } let output = 'return outputValue;'; @@ -734,7 +734,7 @@ export class CoordsGlslLib extends GlslLib { if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { - unpackedCoordsSnippet = inputLayout.unpackedShape.map((s, i) => `coords.${fields[i + rankDiff]}`).join(', '); + unpackedCoordsSnippet = inputLayout.unpackedShape.map((_s, i) => `coords.${fields[i + rankDiff]}`).join(', '); } const source = ` float ${funcName}() { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts b/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts index 709f883ae12c9..d0e589a428825 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts @@ -12,7 +12,7 @@ import {getChannels, unpackFromChannel} from './packing-utils'; const createPackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => ({ name: 'Concat (packed)', - inputNames: Array.from({length: inputCount}, (v, i) => `X${i}`), + inputNames: Array.from({length: inputCount}, (_v, i) => `X${i}`), inputTypes: Array(inputCount).fill(TextureType.packed), cacheHint }); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/concat.ts b/js/web/lib/onnxjs/backends/webgl/ops/concat.ts index c2b18ef86f814..f85f4032feae1 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/concat.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/concat.ts @@ -30,13 +30,13 @@ export const concat: OperatorImplementation = const createUnpackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => ({ name: 'Concat', - inputNames: Array.from({length: inputCount}, (v, i) => `X${i}`), + inputNames: Array.from({length: inputCount}, (_v, i) => `X${i}`), inputTypes: Array(inputCount).fill(TextureType.unpacked), cacheHint }); const createUnpackedConcatProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { + (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { const inputShape = inputs[0].dims.slice(); if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { throw new Error('axis specified for concat doesn\'t match input dimensionality'); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/gather.ts b/js/web/lib/onnxjs/backends/webgl/ops/gather.ts index 54b6ccd1a3685..bb44a20d75f34 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/gather.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/gather.ts @@ -30,7 +30,7 @@ const gatherProgramMetadata = { }; const createGatherProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { + (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { const inputShape = inputs[0].dims.slice(); const indexDataShape = inputs[1].dims.slice(); const outputShape = new Array(inputShape.length + indexDataShape.length - 1); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts b/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts index f74c35b612665..a1da13ec48d70 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts @@ -15,7 +15,7 @@ const createIm2ColProgramMetadata = (cacheHint: string) => ({ }); const createIm2ColProgramInfo = - (inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, x: Tensor, w: Tensor, + (_inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, x: Tensor, w: Tensor, outputShape: readonly number[], attributes: ConvAttributes): ProgramInfo => { const xshape = x.dims; const wshape = w.dims; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts b/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts index 1cd5288251433..efc79f686c960 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts @@ -35,7 +35,7 @@ const imageScalerProgramMetadata = { }; const createImageScalerProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], attributes: ImageScalerAttributes): + (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], attributes: ImageScalerAttributes): ProgramInfo => { const outputShape = inputs[0].dims.slice(); const rank = outputShape.length; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts index fb3c2357ae8fe..0be6d1ba8bcd2 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts @@ -107,10 +107,10 @@ function getBcastSamplerForMatmul( const rankADiff = outRank - inARank; const rankBDiff = outRank - inBRank; - unpackedACoordsSnippet = inAShape.map((s, i) => `coords.${allGlChannels[i + rankADiff]}`); + unpackedACoordsSnippet = inAShape.map((_s, i) => `coords.${allGlChannels[i + rankADiff]}`); unpackedACoordsSnippet[inARank - 1] = 'i*2'; unpackedACoordsSnippet.join(', '); - unpackedBCoordsSnippet = inBShape.map((s, i) => `coords.${allGlChannels[i + rankBDiff]}`); + unpackedBCoordsSnippet = inBShape.map((_s, i) => `coords.${allGlChannels[i + rankBDiff]}`); unpackedBCoordsSnippet[inBRank - 2] = 'i*2'; unpackedBCoordsSnippet.join(', '); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts index 704128fb4858e..523165f29f852 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts @@ -117,7 +117,7 @@ export function getBiasForMatmul( if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { - unpackedCoordsSnippet = inShape.map((s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', '); + unpackedCoordsSnippet = inShape.map((_s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', '); } const broadcastDims = BroadcastUtil.getBroadcastDims(inShape, outShape); const coordsSnippet = broadcastDims.map(d => `coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n'); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts b/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts index 1a2bc7422c833..c9ea460a6f1fc 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts @@ -46,7 +46,7 @@ export const parseReduceAttributes: OperatorInitialization = ( }; const createReduceProgramInfo = - (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes, name: string, reduceOp: ReduceOp, + (_handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes, _name: string, reduceOp: ReduceOp, reduceProgramMetadata: ProgramMetadata): ProgramInfo => { const outputShape: number[] = []; const iRank = inputs[0].dims.length || 1; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts index 51acf5042d8bd..c2d703ed04fa0 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts @@ -4,7 +4,7 @@ import {Tensor} from '../../../tensor'; import {WebGLInferenceHandler} from '../inference-handler'; -export const shape = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { +export const shape = (_inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { validateInputs(inputs); return [new Tensor([inputs[0].dims.length], 'int32', undefined, undefined, new Int32Array(inputs[0].dims))]; }; @@ -13,4 +13,4 @@ const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Shape requires 1 input.'); } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/slice.ts b/js/web/lib/onnxjs/backends/webgl/ops/slice.ts index d32a76bbc8628..81fc1b7076fdb 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/slice.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/slice.ts @@ -42,8 +42,8 @@ export const parseSliceAttributes: OperatorInitialization = (no }; const createSliceProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SliceAttributes): ProgramInfo => { - const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((val, i) => i) : attributes.axes; + (_inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SliceAttributes): ProgramInfo => { + const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((_val, i) => i) : attributes.axes; const normalizedAxes = ShapeUtil.normalizeAxes(axes, input.dims.length); const starts = attributes.starts.map((start, i) => { if (start > input.dims[normalizedAxes[i]] - 1) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/split.ts b/js/web/lib/onnxjs/backends/webgl/ops/split.ts index d1bd00d47eebd..2ab14563d80e2 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/split.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/split.ts @@ -49,13 +49,13 @@ export const parseSplitAttributes: OperatorInitialization = (no }; const getProgramCount = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number, attributes: SplitAttributes): number => { + (_inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number, attributes: SplitAttributes): number => { const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs); return offsets.length; }; const createSplitProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SplitAttributes, axis: number, index: number): + (_inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SplitAttributes, axis: number, index: number): ProgramInfo => { const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, attributes.split, attributes.numOutputs); const offset = offsets[index]; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/sum.ts b/js/web/lib/onnxjs/backends/webgl/ops/sum.ts index c05286d16f936..2c25b10c5872c 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/sum.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/sum.ts @@ -11,7 +11,7 @@ export const sum = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): const sumProgramMetadata = { name: 'Sum', - inputNames: inputs.map((v, i) => `X${i}`), + inputNames: inputs.map((_v, i) => `X${i}`), inputTypes: new Array(inputs.length).fill(TextureType.unpacked) }; @@ -24,7 +24,7 @@ const createSumProgramInfo = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], sumProgramMetadata: ProgramMetadata): ProgramInfo => { const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); const outputShape = inputs[0].dims.slice(); - const sumLine = inputs.map((v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + '); + const sumLine = inputs.map((_v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + '); const shaderSource = ` void main() { vec4 result = ${sumLine}; @@ -65,4 +65,4 @@ const validateInputs = (inputs: Tensor[]): void => { throw new Error('Input types are not matched.'); } } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/tile.ts b/js/web/lib/onnxjs/backends/webgl/ops/tile.ts index 42128c7abc48c..1d2cba7d9d75f 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/tile.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/tile.ts @@ -22,7 +22,7 @@ export const tile = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): }; const createTileProgramInfo = - (handler: WebGLInferenceHandler, inputs: Tensor[], tileProgramMetadata: ProgramMetadata): ProgramInfo => { + (_handler: WebGLInferenceHandler, inputs: Tensor[], tileProgramMetadata: ProgramMetadata): ProgramInfo => { const inputShape = inputs[0].dims.slice(); const outputShape = new Array(inputShape.length); @@ -63,4 +63,4 @@ const validateInputs = (inputs: Tensor[]): void => { if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') { throw new Error('Invalid repeat type.'); } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts b/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts index 815ff13f1f925..d3e7b3c0823be 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts @@ -36,7 +36,7 @@ export const parseTransposeAttributes: OperatorInitialization createAttributeWithCacheKey({perm: node.attributes.getInts('perm', [])}); const createTransposeProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, perm: number[]): ProgramInfo => { + (_inferenceHandler: WebGLInferenceHandler, input: Tensor, perm: number[]): ProgramInfo => { const inputShape = input.dims; perm = getAdjustedPerm(inputShape, perm); const unpackedOutputShape = getOutputShape(inputShape, perm); diff --git a/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts b/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts index 5135c63702316..4b0cf3f037921 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts @@ -82,7 +82,7 @@ export class RedFloat32DataEncoder implements DataEncoder { } decode(buffer: Encoder.DataArrayType, dataSize: number): Float32Array { if (this.channelSize === 1) { - const filteredData = (buffer as Float32Array).filter((value, index) => index % 4 === 0).subarray(0, dataSize); + const filteredData = (buffer as Float32Array).filter((_value, index) => index % 4 === 0).subarray(0, dataSize); return filteredData; } return buffer.subarray(0, dataSize) as Float32Array; @@ -119,7 +119,7 @@ export class RGBAFloatDataEncoder implements DataEncoder { } decode(buffer: Encoder.DataArrayType, dataSize: number): Float32Array { if (this.channelSize === 1) { - const filteredData = (buffer as Float32Array).filter((value, index) => index % 4 === 0).subarray(0, dataSize); + const filteredData = (buffer as Float32Array).filter((_value, index) => index % 4 === 0).subarray(0, dataSize); return filteredData; } return buffer.subarray(0, dataSize) as Float32Array; diff --git a/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts b/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts index c89ef3d23638d..f8e370747928c 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts @@ -105,7 +105,7 @@ export class PreferLogicalStrategy implements TextureLayoutStrategy { // tensor has 3 rows, we pretend it has 4 rows in order to account for the // fact that the texels containing the third row are half empty. logShape = logShape.map( - (d, i) => i >= logShape.length - 2 ? (logShape[i] % 2 === 0 ? logShape[i] : logShape[i] + 1) : logShape[i]); + (_d, i) => i >= logShape.length - 2 ? (logShape[i] % 2 === 0 ? logShape[i] : logShape[i] + 1) : logShape[i]); // Packed texture height is at least 2 (the channel height of a single // texel). @@ -182,7 +182,7 @@ export function parseAxisParam(axis: number|number[], shape: number[]): number[] const rank = shape.length; // Normalize input - axis = axis == null ? shape.map((s, i) => i) : ([] as number[]).concat(axis); + axis = axis == null ? shape.map((_s, i) => i) : ([] as number[]).concat(axis); // Check for valid range assert( diff --git a/js/web/lib/onnxjs/backends/webgl/texture-manager.ts b/js/web/lib/onnxjs/backends/webgl/texture-manager.ts index 1cb113e5fa630..effb65288dc1c 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-manager.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-manager.ts @@ -172,7 +172,7 @@ export class TextureManager { throw new Error(`TensorData type ${dataType} is not supported`); } } - toTextureData(dataType: Tensor.DataType, data: Tensor.NumberType|undefined): Encoder.DataArrayType|undefined { + toTextureData(_dataType: Tensor.DataType, data: Tensor.NumberType|undefined): Encoder.DataArrayType|undefined { if (!data) { return undefined; } diff --git a/js/web/lib/onnxjs/execution-plan.ts b/js/web/lib/onnxjs/execution-plan.ts index b95e639817dbf..5599087ab46f5 100644 --- a/js/web/lib/onnxjs/execution-plan.ts +++ b/js/web/lib/onnxjs/execution-plan.ts @@ -114,7 +114,7 @@ export class ExecutionPlan { // resolve downstream nodes const downstreamNodes = new Set(); - outputList.forEach((output, i) => { + outputList.forEach((_output, i) => { const j = thisOp.node.outputs[i]; for (const currentDownstreamNodeIndex of graphValues[j].to) { const currentDownstreamNode = graphNodes[currentDownstreamNodeIndex]; diff --git a/js/web/lib/onnxjs/instrument.ts b/js/web/lib/onnxjs/instrument.ts index 4c543cab157d7..4f865503d50ec 100644 --- a/js/web/lib/onnxjs/instrument.ts +++ b/js/web/lib/onnxjs/instrument.ts @@ -176,7 +176,7 @@ function createCategorizedLogger(category: string): Logger.CategorizedLogger { // NOTE: argument 'category' is put the last parameter beacause typescript // doesn't allow optional argument put in front of required argument. This // order is different from a usual logging API. -function logInternal(severity: Logger.Severity, content: string, stack: number, category?: string) { +function logInternal(severity: Logger.Severity, content: string, _stack: number, category?: string) { const config = LOGGER_CONFIG_MAP[category || ''] || LOGGER_CONFIG_MAP['']; if (SEVERITY_VALUE[severity] < SEVERITY_VALUE[config.minimalSeverity]) { return; diff --git a/js/web/lib/onnxjs/util.ts b/js/web/lib/onnxjs/util.ts index 0a76d75e79bbf..d697a8b3138cf 100644 --- a/js/web/lib/onnxjs/util.ts +++ b/js/web/lib/onnxjs/util.ts @@ -967,7 +967,7 @@ export class ReduceUtil { const dims = a.dims.slice(0); // if axes is not set, perform reduce on all axes if (axes.length === 0) { - dims.forEach((d, ind) => axes.push(ind)); + dims.forEach((_d, ind) => axes.push(ind)); } // get a temporary broadcastable output shape const outputDims = ReduceUtil.calcReduceShape(dims, axes, true); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index eb40da048835e..e2c2bc8deccf4 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -54,20 +54,22 @@ const getProgramInputTensorInfoDependencyKey = * program. if the key is the same, the program shader source should be the same, so we can reuse the program. * */ -const getProgramInfoUniqueKey = (programInfo: ProgramInfo, inputTensors: readonly TensorView[]): string => { - // final key format: - // []:||... - let key = programInfo.name; - if (programInfo.shaderCache?.hint) { - key += '[' + programInfo.shaderCache.hint + ']'; - } - key += `:${ - getProgramInputTensorInfoDependencyKey( - inputTensors, - programInfo.shaderCache?.inputDependencies ?? - new Array(inputTensors.length).fill('dims'))}`; - return key; -}; +const getProgramInfoUniqueKey = + (programInfo: ProgramInfo, inputTensors: readonly TensorView[], is1DimensionDispatch: boolean): string => { + // final key format: + // []:is1DimensionDispatch:||... + let key = programInfo.name; + if (programInfo.shaderCache?.hint) { + key += '[' + programInfo.shaderCache.hint + ']'; + } + key += ':' + is1DimensionDispatch + + `:${ + getProgramInputTensorInfoDependencyKey( + inputTensors, + programInfo.shaderCache?.inputDependencies ?? + new Array(inputTensors.length).fill('dims'))}`; + return key; + }; /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as @@ -283,10 +285,6 @@ export class WebGpuBackend { inputDatas[i] = gpuData; } - // get program info - const key = getProgramInfoUniqueKey(program, inputTensorViews); - let artifact = this.programManager.getArtifact(key); - const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews); // check output indices @@ -345,6 +343,9 @@ export class WebGpuBackend { let maxAlignmentOfField = 1; programUniforms.forEach(v => { const data = typeof v.data === 'number' ? [v.data] : v.data; + if (data.length === 0) { + return; + } // https://www.w3.org/TR/WGSL/#alignof let baseAlignment: number; switch (data.length) { @@ -404,9 +405,11 @@ export class WebGpuBackend { uniformBufferBinding = {offset: 0, size: currentOffset, buffer: uniformBufferData.buffer}; } - const normalizedDispatchGroup = this.programManager.normalizeDispatchGroupSize(dispatchGroup); - + const is1DimensionDispatch = normalizedDispatchGroup[1] === 1 && normalizedDispatchGroup[2] === 1; + // get program info + const key = getProgramInfoUniqueKey(program, inputTensorViews, is1DimensionDispatch); + let artifact = this.programManager.getArtifact(key); if (!artifact) { artifact = this.programManager.build(program, normalizedDispatchGroup); this.programManager.setArtifact(key, artifact); diff --git a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts index adba0fb9d022d..ad56b92c1d869 100644 --- a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts +++ b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts @@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl { Object.assign(this, attribute); } - private _cacheKey: string; + private key: string; public get cacheKey(): string { - if (!this._cacheKey) { - this._cacheKey = + if (!this.key) { + this.key = Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); } - return this._cacheKey; + return this.key; } } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index a4d51e68b6a25..bac44328d8f44 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; +import {attention, parseAttentionAttributes} from './ops/attention'; import {biasAdd} from './ops/bias-add'; import {biasSplitGelu} from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; @@ -16,6 +17,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; import {range} from './ops/range'; @@ -46,13 +48,13 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], ['Atanh', [unaryOps.atanh]], + ['Attention', [attention, parseAttentionAttributes]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], ['BiasAdd', [biasAdd]], ['BiasSplitGelu', [biasSplitGelu]], ['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]], ['Ceil', [unaryOps.ceil]], - ['ClipV10', [unaryOps.clipV10]], ['Clip', [unaryOps.clip]], ['Concat', [concat, parseConcatAttributes]], ['Conv', [conv, parseConvAttributes]], @@ -86,6 +88,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], + ['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]], ['Neg', [unaryOps.neg]], ['Not', [unaryOps.not]], ['Pad', [pad, parsePadAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index 6481a6b21d723..a121bf3892a32 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -19,8 +19,6 @@ // // modified to fit the needs of the project -export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu'; - export const typeSnippet = (component: number, dataType: string) => { switch (component) { case 1: @@ -36,17 +34,6 @@ export const typeSnippet = (component: number, dataType: string) => { } }; -export const activationFnSnippet = - (activation?: Activation, _hasPreluActivationWeights = false, _packed = false, _coordsLength = 3): string => { - if (!activation) { - return ''; - } - // TODO: add implementations - return ''; - }; - -export const biasActivationSnippet = (hasBias: boolean, activation?: Activation): string => ` +export const biasSnippet = (hasBias: boolean): string => ` ${hasBias ? 'value = value + getBiasByOutputCoords(coords);' : ''} - // TODO uncomment the following line when activation is supported above. - // ${activation ? 'value = activation(value, coords);' : ''} `; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index fbb936a045b9c..089e783d7e22f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -25,15 +25,16 @@ import {ShapeUtil} from '../../../util'; import {ProgramInfo} from '../../types'; import {tensorTypeToWsglStorageType} from '../common'; import {ConvAttributes} from '../conv'; +import {getActivationSnippet} from '../fuse-utils'; -import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dCommonSnippet = (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, - activation?: Activation, hasPreluActivationWeights = false, innerElementSizeX = 4, innerElementSizeW = 4, - innerElementSize = 4, dataType = 'f32'): string => { + attributes: ConvAttributes, innerElementSizeX = 4, innerElementSizeW = 4, innerElementSize = 4, + dataType = 'f32'): string => { const getXSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -129,8 +130,9 @@ const conv2dCommonSnippet = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); + const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType); const userCode = ` - ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} + ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} } @@ -146,7 +148,8 @@ const conv2dCommonSnippet = var value = valueIn; let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; ${coordResSnippet} - ${biasActivationSnippet(addBias, activation)} + ${biasSnippet(addBias)} + ${applyActivation} setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value); } }`; @@ -242,8 +245,7 @@ export const createConv2DMatMulProgramInfo = ${declareFunctions} ${ conv2dCommonSnippet( - isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, - attributes.activation.toLowerCase() as Activation, false, elementsSize[0], elementsSize[1], + isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], elementsSize[2], t)} ${ isVec4 ? diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index a95d3830f34eb..85cf7bf87f52c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -24,14 +24,14 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo} from '../../types'; import {ConvTransposeAttributes} from '../conv-transpose'; +import {getActivationSnippet} from '../fuse-utils'; -import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dTransposeCommonSnippet = - (isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, - innerElementSize = 4): string => { + (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => { const type = typeSnippet(innerElementSize, 'f32'); const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { @@ -129,9 +129,9 @@ const conv2dTransposeCommonSnippet = return ${type}(0.0); `; - + const {activationFunction, applyActivation} = getActivationSnippet(attributes, type); const userCode = ` - ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} + ${activationFunction} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } @@ -146,7 +146,8 @@ const conv2dTransposeCommonSnippet = var value = valueInput; let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; ${coordResSnippet} - ${biasActivationSnippet(addBias, activation)} + ${biasSnippet(addBias)} + ${applyActivation} result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; } }`; @@ -236,9 +237,7 @@ export const createConv2DTransposeMatMulProgramInfo = const dimBOuter : i32 = ${dimBOuter}; const dimInner : i32 = ${dimInner}; ${declareFunctions} - ${ - conv2dTransposeCommonSnippet( - isChannelsLast, hasBias, attributes.activation.toLowerCase() as Activation, false, innerElementSize)} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 0a0f29db6a494..335de01c596b7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -440,7 +440,6 @@ export const createMatmulProgramInfo = const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; - const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, isVec4); // TODO: fine tune size const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; @@ -462,6 +461,7 @@ export const createMatmulProgramInfo = variables.push(output); const inputVariables = [A, B]; const hasBias = inputs.length > 2; + const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); const declareFunctions = matMulReadWriteFnSource(components, hasBias, applyActivation, variables, batchShapes, isChannelsLast); if (hasBias) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts index 2f8d072c7dd4f..b6c6853c8f222 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts @@ -27,11 +27,6 @@ export interface ArgMinMaxAttributes extends AttributeWithCacheKey { selectLastIndex: number; } -const createArgMinMaxAttributesFromInputs = - (inputs: readonly TensorView[], attributes: ArgMinMaxAttributes): ArgMinMaxAttributes => - createAttributeWithCacheKey( - {axis: attributes.axis, keepDims: attributes.keepDims, selectLastIndex: attributes.selectLastIndex}); - export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => { validateInputs(context.inputs); const argMinMaxOp: ReduceOp = (input, output, axes) => { @@ -51,12 +46,10 @@ export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes) ]; }; - const updatedAttributes: ArgMinMaxAttributes = - context.inputs.length === 1 ? attributes : createArgMinMaxAttributesFromInputs(context.inputs, attributes); context.compute( createReduceProgramInfo( - 'ArgMin', {hint: updatedAttributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [updatedAttributes.axis], - DataType.int64, updatedAttributes.keepDims), + 'ArgMin', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64, + attributes.keepDims), {inputs: [0]}); }; @@ -79,12 +72,10 @@ export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes) ]; }; - const updatedAttributes: ArgMinMaxAttributes = - context.inputs.length === 1 ? attributes : createArgMinMaxAttributesFromInputs(context.inputs, attributes); context.compute( createReduceProgramInfo( - 'argMax', {hint: updatedAttributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [updatedAttributes.axis], - DataType.int64, updatedAttributes.keepDims), + 'argMax', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64, + attributes.keepDims), {inputs: [0]}); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts new file mode 100644 index 0000000000000..e1f2a47301bfb --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -0,0 +1,635 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; + +export const enum AttentionQkvFormat { + unknown, // enum value not set, or depends on qkv projection implementation details + qkvBNSH, // for non-packed qkv, permuted + qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + qkvBSN3H, // for TRT fused attention, qkv are packed + qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed + qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed +} + +export const enum AttentionMaskType { + none, // No mask + mask1dKeySeqLen, // [batch_size], key sequence length + mask1dEndStart, // [2 * batch_size] with end positions and start positions + mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], + // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., + // key_start[batch_size - 1], key_end[batch_size - 1]] + mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. + mask2dKeyPadding, // [batch_size, total_sequence_length] + mask3dAttention, // [batch_size, sequence_length, total_sequence_length] + mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + maskUnknown +} + +export interface AttentionParameters { + batchSize: number; + sequenceLength: number; + pastSequenceLength: number; + kvSequenceLength: number; + totalSequenceLength: number; + maxSequenceLength: number; + inputHiddenSize: number; + hiddenSize: number; + vHiddenSize: number; + headSize: number; + vHeadSize: number; + numHeads: number; + isUnidirectional: boolean; + pastPresentShareBuffer: boolean; + maskFilterValue: number; + maskType: AttentionMaskType; + scale: number; + broadcastResPosBias: boolean; + passPastInKv: boolean; + qkvFormat: AttentionQkvFormat; +} + +export interface AttentionAttrs { + numHeads: number; + isUnidirectional: number; + maskFilterValue: number; + scale: number; + doRotary: number; + qkvHiddenSizes: number[]; + pastPresentShareBuffer: boolean; +} + +const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // When past state is used, Q, K and V should have same hidden size (unless we split it into past_key and past_value). + + // Input shapes: + // input (Q/K/V) : (B, S, D_i) + // weights (Q/K/V) : (D_i, D + D + D_v) + // bias (Q/K/V) : (D + D + D_v) + // mask_index : see below + // past (K/V) : (2, B, N, P, H) or NULL + // relative_position_bias : (B, N, S, T) or NULL + + // For mask_index, the following shapes are supported: + // NULL, (B, 1), (1, 1) + // (B), (2 * B), (3 * B + 2) + // (B, T) + // (B, S, T) + // (B, 1, M, M) + // + // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger + // than hidden dimension of Q, K and V. + + const input = inputs[0]; + const weights = inputs[1]; + const bias = inputs[2]; + const maskIndex = inputs[3]; + const past = inputs[4]; + const relativePositionBias = inputs[5]; + + if (past && relativePositionBias) { + throw new Error('Attention cannot have both past and relative_position_bias'); + } + + if (input.dims.length !== 3) { + throw new Error('Input "input" must have 3 dimensions'); + } + + const batchSize = input.dims[0]; + const sequenceLength = input.dims[1]; + const inputHiddenSize = input.dims[2]; + + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimensions'); + } + + if (weights.dims.length !== 2) { + throw new Error('Input "weights" is expected to have 2 dimensions'); + } + + if (weights.dims[0] !== inputHiddenSize) { + throw new Error('Input 1 dimension 0 should have same length as dimension 2 of input 0'); + } + + if (bias.dims[0] !== weights.dims[1]) { + throw new Error('Input "bias" dimension 0 should have same length as dimension 1 of input "weights"'); + } + + let qHiddenSize = bias.dims[0] / 3; + let kHiddenSize = qHiddenSize; + let vHiddenSize = kHiddenSize; + if (attributes.qkvHiddenSizes.length > 0) { + if (attributes.qkvHiddenSizes.length !== 3) { + throw new Error('qkv_hidden_sizes attribute should have 3 elements'); + } + for (const sz of attributes.qkvHiddenSizes) { + if (sz % attributes.numHeads !== 0) { + throw new Error('qkv_hidden_sizes should be divisible by num_heads'); + } + } + + qHiddenSize = attributes.qkvHiddenSizes[0]; + kHiddenSize = attributes.qkvHiddenSizes[1]; + vHiddenSize = attributes.qkvHiddenSizes[2]; + } + + const kvSequenceLength = sequenceLength; + + if (qHiddenSize !== kHiddenSize) { + throw new Error('qkv_hidden_sizes first element should be same as the second'); + } + + if (bias.dims[0] !== qHiddenSize + kHiddenSize + vHiddenSize) { + throw new Error('Input "bias" dimension 0 should have same length as sum of Q/K/V hidden sizes'); + } + + let pastSequenceLength = 0; + if (past) { + if (kHiddenSize !== vHiddenSize) { + throw new Error('Input "past" expect k_hidden_size == v_hidden_size'); + } + if (past.dims.length !== 5) { + throw new Error('Input "past" must have 5 dimensions'); + } + if (past.dims[0] !== 2) { + throw new Error('Input "past" first dimension must be 2'); + } + if (past.dims[1] !== batchSize) { + throw new Error('Input "past" second dimension must be batch_size'); + } + if (past.dims[2] !== attributes.numHeads) { + throw new Error('Input "past" third dimension must be num_heads'); + } + if (past.dims[4] !== kHiddenSize / attributes.numHeads) { + throw new Error('Input "past" fifth dimension must be k_hidden_size / num_heads'); + } + + if (!attributes.pastPresentShareBuffer) { + pastSequenceLength = past.dims[3]; + } + // TODO: handle past_seq_len + } + + const totalSequenceLength = kvSequenceLength + pastSequenceLength; + const maxSequenceLength = -1; + + const maskType = AttentionMaskType.none; + if (maskIndex) { + // maskType = AttentionMaskType.MASK_UNKNOWN; + // TODO: handle mask + throw new Error('Mask not supported'); + } + + if (past) { + throw new Error('past is not supported'); + } + if (relativePositionBias) { + throw new Error('relativePositionBias is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize, + hiddenSize: qHiddenSize, + vHiddenSize, + headSize: Math.floor(qHiddenSize / attributes.numHeads), + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias: false, + passPastInKv: false, + qkvFormat: AttentionQkvFormat.qkvBNSH, + }; +}; + +export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => { + const components = getMaxComponents(d); + const inputHelper = outputVariable('x', input.dataType, input.dims, components); + + let threadMaxValue = 'threadMaxVector'; + if (components === 2) { + threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)'; + } else if (components === 4) { + threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))'; + } + const dataType = tensorTypeToWsglStorageType(input.dataType); + let WG = 64; + const dComp = d / components; + if (dComp < WG) { + WG = 1; + } else if (dComp / 8 < 64) { + WG = Math.ceil(dComp / 8); + } + const elementsPerWG = Math.ceil(d / components / WG); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const dInv: ${dataType} = 1 / ${d}; + const dComp = ${d / components}; + var wgMax: array; + var wgSum: array; + + ${shaderHelper.declareVariables(inputHelper)} + @compute @workgroup_size(${WG}, 1, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_index) local_index : u32) { + let localOffset = local_index * ${elementsPerWG}; + let offset: u32 = workgroup_id.x * dComp + localOffset; + + var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')}; + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector); + } + wgMax[local_index] = ${threadMaxValue}; + workgroupBarrier(); + + var maxValue = -3.402823e+38f; + for (var i = 0u; i < ${WG}; i++) { + maxValue = max(wgMax[i], maxValue); + } + + var sumVector = ${fillVector('f32', components, '0')}; + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue); + } + wgSum[local_index] = ${sumVector('sumVector', components)}; + workgroupBarrier(); + + var sum: f32 = 0; + for (var i = 0u; i < ${WG}; i++) { + sum += wgSum[i]; + } + + if (sum == 0) { + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + x[offset + i] = ${fillVector(dataType, components, 'dInv')}; + } + } else { + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + let f32input = ${castToF32(dataType, components, 'x[offset + i]')}; + x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum); + } + } + }`; + + context.compute( + { + name: 'AttentionProbsSoftmax', + shaderCache: {hint: `${d}`}, + getShaderSource, + getRunData: () => ({ + outputs: [], + dispatchGroup: {x: n}, + }), + }, + {inputs: [input], outputs: []}); +}; + +const computeAttentionProbs = + (context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined, + parameters: AttentionParameters, attributes: AttentionAttrs) => { + const probsShape = [ + parameters.batchSize, parameters.numHeads, parameters.sequenceLength, + parameters.kvSequenceLength + parameters.pastSequenceLength + ]; + // TODO: handle mask + + const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; + + const dataType = tensorTypeToWsglStorageType(q.dataType); + + const components = getMaxComponents(parameters.headSize); + const qInput = inputVariable('q', q.dataType, q.dims, components); + const kInput = inputVariable('key', key.dataType, key.dims, components); + const output = outputVariable('output', q.dataType, probsShape); + + const vectorizedHeadSize = parameters.headSize / components; + const M = parameters.sequenceLength; + const N = parameters.totalSequenceLength; + const K = vectorizedHeadSize; + + const TILE_SIZE = 12; + + const dispatch = { + x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + + const inputs = [q, key]; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${M}u; + const N: u32 = ${N}u; + const K: u32 = ${K}u; + const alpha: ${dataType} = ${alpha}; + const beta: ${dataType} = 1.0; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + + ${shaderHelper.declareVariables(qInput, kInput, output)} + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + // x holds the N and y holds the M + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE; + let n = workgroup_id.x * TILE_SIZE; + let lm = m + local_id.y; + let ln = n + local_id.x; + + let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; + let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; + + var value = ${fillVector(dataType, components)}; + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m + local_id.y < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x]; + } + if (n + local_id.y < N && w + local_id.x < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x]; + } + workgroupBarrier(); + + for (var k: u32 = 0u; k ({ + outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs, outputs: [-1]})[0]; + + computeInPlaceSoftmax( + context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength, + parameters.totalSequenceLength); + + return probs; + }; + +const computeVxAttentionScore = + (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { + const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize]; + + const probsHelper = inputVariable('probs', probs.dataType, probs.dims); + const vHelper = inputVariable('v', v.dataType, v.dims); + const output = outputVariable('output', probs.dataType, outputShape); + + const dataType = tensorTypeToWsglStorageType(probs.dataType); + + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(params.vHeadSize / TILE_SIZE), + y: Math.ceil(params.sequenceLength / TILE_SIZE), + z: params.batchSize * params.numHeads + }; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${params.sequenceLength}u; + const N: u32 = ${params.vHeadSize}u; + const K: u32 = ${params.totalSequenceLength}u; + const numHeads: u32 = ${params.numHeads}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + + ${shaderHelper.declareVariables(probsHelper, vHelper, output)} + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; + + let offsetA = headIdx * (M * K) + m * K; + let offsetB = headIdx * (N * K) + n; + + var value = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({ + outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs: [probs, v], outputs: [0]})[0]; + }; + +export const applyAttention = + (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined, + _past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined, + relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { + const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes); + + computeVxAttentionScore(context, probs, v, parameters); + }; + +const prepare = (context: ComputeContext, parameters: AttentionParameters) => { + const outputShape = [ + parameters.batchSize, + parameters.numHeads, + parameters.sequenceLength, + parameters.headSize, + ]; + + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + + const M = parameters.sequenceLength; + const K = parameters.inputHiddenSize; + const N = parameters.headSize; + + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(parameters.headSize / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + + const getShaderSource = () => ` + const M: u32 = ${M}u; + const K: u32 = ${K}u; + const N: u32 = ${N}u; + const numHeads: u32 = ${parameters.numHeads}; + const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + + @group(0) @binding(0) var input: array<${dataType}>; + @group(0) @binding(1) var weight: array<${dataType}>; + @group(0) @binding(2) var bias: array<${dataType}>; + @group(0) @binding(3) var outputQ: array<${dataType}>; + @group(0) @binding(4) var outputK: array<${dataType}>; + @group(0) @binding(5) var outputV: array<${dataType}>; + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + let batchIndex = workgroup_id.z / ${parameters.numHeads}; + let headNumber = workgroup_id.z % ${parameters.numHeads}; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; + + let inputOffset = batchIndex * (M * K) + m * K; + let biasOffsetQ = headNumber * ${parameters.headSize}; + let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ; + let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK; + + var valueQ = ${dataType}(0); + var valueK = ${dataType}(0); + var valueV = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + let offset = n + (w + local_id.y) * ldb; + tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset]; + tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset]; + tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({ + outputs: [ + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + ], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs, outputs: [-1, -1, -1]}); +}; + +export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateAttentionInputs(context.inputs, attributes); + + const [q, k, v] = prepare(context, params); + + return applyAttention( + context, q, k, v, context.inputs[4], undefined, undefined, undefined, context.inputs[5], params, attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 0992552ebaf25..c033c0ba05356 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -17,8 +17,9 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ const createBinaryOpProgramShader = (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], - vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, typeA: number, typeB: number, - typeOutput: number, useShapesUniforms: boolean, additionalImplementation?: string) => { + vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall, + typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean, + additionalImplementation?: string) => { let expressionScalar: BinaryCustomExpression; let expressionVector: BinaryCustomExpression; if (typeof funcCall === 'string') { @@ -42,6 +43,8 @@ const createBinaryOpProgramShader = if (doBroadcast) { const isAOneElement = ShapeUtil.size(dimsA) === 1; const isBOneElement = ShapeUtil.size(dimsB) === 1; + const aLastDimDivisibleBy4 = dimsA.length > 0 && dimsA[dimsA.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = dimsB.length > 0 && dimsB[dimsB.length - 1] % 4 === 0; if (isAOneElement || isBOneElement) { assignment = output.setByOffset( 'global_idx', @@ -55,7 +58,14 @@ const createBinaryOpProgramShader = let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)}; ${ output.setByOffset( - 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} + 'global_idx', + expressionVector( + sharedDimensionDivisibleBy4 || aLastDimDivisibleBy4 ? + a.getByOffset('offsetA / 4u') : + `${a.type.value}(${a.getByOffset('offsetA / 4u')}[offsetA % 4u])`, + sharedDimensionDivisibleBy4 || bLastDimDivisibleBy4 ? + b.getByOffset('offsetB / 4u') : + `${b.type.value}(${b.getByOffset('offsetB / 4u')}[offsetB % 4u])`))} `; } } else { @@ -118,9 +128,9 @@ const createBinaryOpProgramInfo = let outputSize = ShapeUtil.size(a.dims); let vectorize = false; + let sharedDimensionDivisibleBy4 = false; // TODO: deal with zero-sized tensors (eg. dims=[1,0]) - const cacheKeyAux = [isBroadcast]; if (isBroadcast) { const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); @@ -131,8 +141,12 @@ const createBinaryOpProgramInfo = outputSize = ShapeUtil.size(outputShape); const isAOneElement = ShapeUtil.size(a.dims) === 1; const isBOneElement = ShapeUtil.size(b.dims) === 1; + const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0; cacheKeyAux.push(isAOneElement); cacheKeyAux.push(isBOneElement); + cacheKeyAux.push(aLastDimDivisibleBy4); + cacheKeyAux.push(bLastDimDivisibleBy4); // check whether vectorize can be enabled let sharedDimension = 1; for (let i = 1; i < outputShape.length; i++) { @@ -144,7 +158,10 @@ const createBinaryOpProgramInfo = break; } } - if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) { + if (sharedDimension % 4 === 0) { + sharedDimensionDivisibleBy4 = true; + vectorize = true; + } else if (isAOneElement || isBOneElement || aLastDimDivisibleBy4 || bLastDimDivisibleBy4) { vectorize = true; } } else { @@ -158,14 +175,11 @@ const createBinaryOpProgramInfo = name, shaderCache: { hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'), - // If the input is scalar then use type instead of dims because useShapesUniforms is false. - inputDependencies: useShapesUniforms ? - ['rank', 'rank'] : - [a.dims.length > 0 ? 'dims' : 'type', b.dims.length > 0 ? 'dims' : 'type'], + inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'], }, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( - shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, - outputDataType, useShapesUniforms, additionalImplementation), + shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, + a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 7352517733c3e..014d9d02f6f10 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -258,8 +258,8 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = return typeof mappedType === 'string' ? mappedType : mappedType[1]; }; -export const createTensorShapeVariables = (dims: readonly number[]): - ProgramUniform[] => [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; +export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => + dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; /** * A helper function to get maximum vector size for specified data length @@ -646,6 +646,8 @@ export const outputVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, false, components); +export type UniformsArrayType = Array<{name: string; type: string}>; + /** * A ShaderHelper is a helper class for generating WGSL code. */ @@ -697,6 +699,7 @@ export interface ShaderHelper { * A helper function to register one uniform. Can be called multiple times to register multiple uniforms. */ registerUniform(name: string, type: string): ShaderHelper; + registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper; } class ShaderHelperImpl implements ShaderHelper { @@ -717,11 +720,12 @@ class ShaderHelperImpl implements ShaderHelper { const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, @builtin(local_invocation_id) local_id : vec3` : `@builtin(local_invocation_index) local_index : u32, - @builtin(workgroup_id) workgroup_id : vec3`; + @builtin(workgroup_id) workgroup_id : vec3, + @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ? 'let global_idx = global_id.x;' : - `let global_idx = (workgroup_id.z * ${this.normalizedDispatchGroup[0] * this.normalizedDispatchGroup[1]}u + - workgroup_id.y * ${this.normalizedDispatchGroup[0]}u + workgroup_id.x) * ${ + `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + + workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${ workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`; return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ}) @@ -732,11 +736,13 @@ class ShaderHelperImpl implements ShaderHelper { private declareVariable(variable: IndicesHelper, bindingIndex: number): string { this.indicesHelpers.push(variable); - if (variable.shape.startsWith('uniforms.')) { - this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices}); - } - if (variable.strides.startsWith('uniforms.')) { - this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices}); + if (variable.rank !== 0) { + if (variable.shape.startsWith('uniforms.')) { + this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices}); + } + if (variable.strides.startsWith('uniforms.')) { + this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices}); + } } const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; @@ -752,8 +758,13 @@ class ShaderHelperImpl implements ShaderHelper { return this; } + registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper { + this.uniforms = this.uniforms.concat(additionalUniforms); + return this; + } + private indicesHelpers: IndicesHelper[] = []; - private uniforms: Array<{name: string; type: string}> = []; + private uniforms: UniformsArrayType = []; private uniformDeclaration(): string { if (this.uniforms.length === 0) { return ''; @@ -805,4 +816,4 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly }; // TODO: remove this limitation once >4D dims are supported by uniform. -export const enableShapesUniforms = (rank: number): boolean => rank <= 4 && rank > 0; +export const enableShapesUniforms = (rank: number): boolean => rank <= 4; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 4b5ca869f0dfb..43cc4a4c080bd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -33,9 +33,10 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const calculateInputIndexImpl = (numberOfTensors: number): string => ` +const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => ` fn calculateInputIndex(index: u32) -> u32 { - for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); + for (var i: u32 = 0u; i < ${numberOfTensors}; i += 1u ) { if (index < sizeInConcatAxis[i]) { return i; } @@ -92,40 +93,69 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P const dataType = inputs[0].dataType; let previousSum = 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + const inputShapeOrRanks = []; + const enableInputShapesUniforms = []; + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; for (let i = 0; i < inputs.length; ++i) { previousSum += inputs[i].dims[adjustedAxis]; sizeInConcatAxis[i] = previousSum; + enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length)); + inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims); + inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]); + inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims'); + programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); + } + for (let i = 0; i < inputs.length; ++i) { + if (enableInputShapesUniforms[i]) { + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); + } + } - inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); } - const output = outputVariable('output', dataType, outputShape); + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(...inputVars, output)} - const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); - ${calculateInputIndexImpl(sizeInConcatAxis.length)} + ${(() => { + shaderHelper.registerUniform('outputSize', 'u32'); + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return shaderHelper.declareVariables(...inputVars, output); + })()} + + ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} var indices = ${output.offsetToIndices('global_idx')}; let inputIndex = calculateInputIndex(${indicesAxis}); if (inputIndex != 0u) { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u]; } ${assignOutputData(inputVars, output)} }`; + return { name: 'Concat', - shaderCache: {hint: `${axis}`}, + shaderCache: {hint: `${axis}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms, }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 8bfa722dd0909..14482272bad38 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -22,14 +22,13 @@ export const createGroupedConvProgramInfo = const wShape = inputs[1].dims; const outputChannelsPerGroup = wShape[0] / attributes.group; - const {activationFunction, applyActivation} = getActivationSnippet(attributes); - const isChannelLast = attributes.format === 'NHWC'; const outputShape = calculateOutputShape( xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); const outputSize = ShapeUtil.size(outputShape); const output = outputVariable('output', inputs[0].dataType, outputShape); + const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); const x = inputVariable('x', inputs[0].dataType, xShape); const w = inputVariable('w', inputs[1].dataType, wShape); const inputVars = [x, w]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index 357eb5c0b84ad..a233d37a79e65 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -254,7 +254,7 @@ const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} var outputIndices = ${output.offsetToIndices('global_idx')}; - ${inputVars.map((inputVar, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} + ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} ${reduceOps.join('\n')}; ${output.setByOffset('global_idx', 'sum')}; }`; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index a920875be1776..0b5c0db2b5112 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -10,28 +10,27 @@ export interface InternalActivationAttributes { readonly activationCacheKey: string; } -export const getActivationSnippet = (attributes: InternalActivationAttributes, isVec4 = false): { - activationFunction: string; applyActivation: string; -} => { - switch (attributes.activation) { - case 'Relu': - return { - activationFunction: '', - applyActivation: isVec4 ? 'value = max(value, vec4(0.0));' : 'value = max(value, 0.0);' - }; - case 'Sigmoid': - return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; - case 'Clip': - return { - activationFunction: `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, - applyActivation: isVec4 ? 'value = clamp(value, vec4(clip_min_), vec4(clip_max_));' : - 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } -}; +export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): + {activationFunction: string; applyActivation: string} => { + switch (attributes.activation) { + case 'Relu': + return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`}; + case 'Sigmoid': + return { + activationFunction: '', + applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));` + }; + case 'Clip': + return { + activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${ + attributes.clipMax!});`, + applyActivation: 'value = clamp(value, clip_min_, clip_max_);' + }; + // TODO: adding other activations that can be fused. + default: + return {activationFunction: '', applyActivation: ''}; + } + }; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index fdcd64abfe4e7..5d6d6debadb9a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -31,20 +31,44 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const axisDimLimit = inputShape[axis]; const outputSize = ShapeUtil.size(outputShape); - const data = inputVariable('data', inputs[0].dataType, inputs[0].dims); - const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims); - const output = outputVariable('output', inputs[0].dataType, outputShape); + const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); + const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; + const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length); + const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims; + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + + const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank); + const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); + const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); + + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; + if (enableInputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + } + if (enableIndicesShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + } + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } + + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims'); + inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims'); + const calcDataIndices = (): string => { const indicesRank = indicesShape.length; let calcStr = `var indicesIndices = ${indices.type.indices}(0);`; for (let i = 0; i < indicesRank; i++) { calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${ - outputShape.length > 1 ? `outputIndices[${axis + i}]` : 'outputIndices'};`; + outputShape.length > 1 ? `outputIndices[uniforms.axis + ${i}]` : 'outputIndices'};`; } calcStr += ` var idx = ${indices.getByIndices('indicesIndices')}; if (idx < 0) { - idx = idx + ${axisDimLimit}; + idx = idx + uniforms.axisDimLimit; } var dataIndices = ${data.type.indices}(0); `; @@ -62,9 +86,13 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath }; const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(data, indices, output)} + ${ + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform('axisDimLimit', 'i32') + .registerUniform('axis', 'u32') + .declareVariables(data, indices, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = ${output.offsetToIndices('global_idx')}; ${calcDataIndices()}; let value = ${data.getByIndices('dataIndices')}; @@ -72,12 +100,13 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath }`; return { name: 'Gather', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: {hint: attributes.cacheKey, inputDependencies}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: inputs[0].dataType}, ], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts new file mode 100644 index 0000000000000..b7726a36bcaad --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -0,0 +1,335 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; + +const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + const query = inputs[0]; + const key = inputs[1]; + const value = inputs[2]; + const bias = inputs[3]; + const keyPaddingMask = inputs[4]; + const relativePositionBias = inputs[5]; + const pastKey = inputs[6]; + const pastValue = inputs[7]; + + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None + // relative_position_bias : (B, 1, S, L) + // past_key : (B, N, S*, H) + // past_value : (B, N, S*, H) + // When no packing for q/k/v: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) or (B, N, S*, H) + // value (V) : (B, L, D_v) or (B, N, S*, H) + // bias (Q/K/V) : (D + D + D_v) + // When packed kv is used: + // query (Q) : (B, S, D) + // key (K) : (B, L, N, 2, H) + // value (V) : None + // bias (Q/K/V) : None + // When packed qkv is used: + // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) + // key (K) : None + // value (V) : None + // bias (Q/K/V) : None or (D + D + D_v) + + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input query is expected to have 3 or 5 dimensions'); + } + + const dmmhaPacking = false; + const batchSize = query.dims[0]; + const sequenceLength = query.dims[1]; + const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : + attributes.numHeads * query.dims[4]; + let kvSequenceLength = sequenceLength; + + let pastSequenceLength = 0; + let maxSequenceLength = 0; + const headSize = Math.floor(hiddenSize / attributes.numHeads); + if (pastKey && pastValue) { + if (pastKey.dims.length !== 4) { + throw new Error('Input "past_key" is expected to have 4 dimensions'); + } + if (pastValue.dims.length !== 4) { + throw new Error('Input "past_value" is expected to have 4 dimensions'); + } + pastSequenceLength = pastKey.dims[2]; + maxSequenceLength = pastKey.dims[2]; + } else if (pastKey || pastValue) { + throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); + } + + let qkvFormat: AttentionQkvFormat; + if (key) { + if (query.dims.length !== 3) { + throw new Error('Input "query" is expected to have 3 dimensions when key is given'); + } + if (key.dims.length < 3 || key.dims.length > 5) { + throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions'); + } + if (query.dims[0] !== key.dims[0]) { + throw new Error('Input "query" and "key" shall have same dim 0 (batch size)'); + } + + if (key.dims.length === 3) { + if (key.dims[2] !== query.dims[2]) { + throw new Error('Input "query" and "key" shall have same dim 2 (hidden_size)'); + } + qkvFormat = AttentionQkvFormat.qkvBSNH; + kvSequenceLength = key.dims[1]; + } else if (key.dims.length === 5) { + if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) { + throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv'); + } + if (value) { + throw new Error('Expect "value" be none when "key" has packed kv format.'); + } + qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; + kvSequenceLength = key.dims[1]; + } else { // key_dims.size() == 4 (cross-attention with past_key) + if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { + throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); + } + + qkvFormat = AttentionQkvFormat.unknown; + kvSequenceLength = key.dims[2]; + } + } else { // packed QKV + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); + } + if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) { + throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv'); + } + + qkvFormat = AttentionQkvFormat.qkvBSN3H; + } + + if (bias) { + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimension'); + } + + if (value) { + if (query.dims.length === 5 && query.dims[3] === 2) { + throw new Error('bias is not allowed for packed kv.'); + } + } + } + + let maskType: AttentionMaskType = AttentionMaskType.none; + if (keyPaddingMask) { + maskType = AttentionMaskType.maskUnknown; + const maskDims = keyPaddingMask.dims; + if (maskDims.length === 1) { + if (maskDims[0] === batchSize) { + maskType = AttentionMaskType.mask1dKeySeqLen; + } else if (maskDims[0] === 3 * batchSize + 2) { + maskType = AttentionMaskType.mask1DKeySeqLenStart; + } + } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) { + maskType = AttentionMaskType.mask2dKeyPadding; + } + if (maskType === AttentionMaskType.maskUnknown) { + throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, kv_sequence_length)'); + } + throw new Error('Mask not supported'); + } + + let passPastInKv = false; + let vHiddenSize = hiddenSize; + if (value) { + if (value.dims.length !== 3 && value.dims.length !== 4) { + throw new Error('Input "value" is expected to have 3 or 4 dimensions'); + } + + if (query.dims[0] !== value.dims[0]) { + throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)'); + } + + if (value.dims.length === 3) { + if (kvSequenceLength !== value.dims[1]) { + throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)'); + } + vHiddenSize = value.dims[2]; + } else { + if (kvSequenceLength !== value.dims[2]) { + throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)'); + } + vHiddenSize = value.dims[1] * value.dims[3]; + passPastInKv = true; + } + } + + const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const broadcastResPosBias = false; + // if (extraAddQk) { + // if (extraAddQk.dims[0] === 1) { + // broadcastResPosBias = true; + // } + // } + + if (keyPaddingMask) { + throw new Error('Key padding mask is not supported'); + } + if (relativePositionBias) { + throw new Error('extraAddQk is not supported'); + } + if (pastKey) { + throw new Error('pastKey is not supported'); + } + if (pastValue) { + throw new Error('pastValue is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize: 0, + hiddenSize, + vHiddenSize, + headSize, + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias, + passPastInKv, + qkvFormat, + }; +}; + + +export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]}); + +const addBiasTranspose = + (context: ComputeContext, qkv: TensorView, bias: TensorView, batchSize: number, sequenceLength: number, + hiddenSize: number, biasOffset: number) => { + const outputShape = [batchSize, sequenceLength, hiddenSize]; + const outputSize = ShapeUtil.size(outputShape); + + const dataType = tensorTypeToWsglStorageType(qkv.dataType); + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const biasOffset = ${biasOffset}u; + const hiddenSize = ${hiddenSize}u; + + @group(0) @binding(0) var qkv: array<${dataType}>; + @group(0) @binding(1) var bias: array<${dataType}>; + @group(0) @binding(2) var qkv_with_bias: array<${dataType}>; + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset; + + qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx]; + }`; + + return context.compute( + { + name: 'MultiHeadAttentionAddBias', + shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + }), + getShaderSource, + }, + {inputs: [qkv, bias], outputs: [-1]})[0]; + }; + +const maybeTransposeToBNSHAndAddBias = + (context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number, + input: TensorView, bias?: TensorView, biasOffset?: number) => { + // const newDims = []; + + let reshapedInput = input; + if (!bias) { + if (input.dims.length === 3) { + reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); + } + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } else { + if (sequenceLength === 1) { + throw new Error('AddBiasReshape is not implemented. Please export your model with packed QKV or KV'); + } else { + reshapedInput = + addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!); + reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } + } + }; + +export const multiHeadAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateInputs(context.inputs, attributes); + + if (context.inputs[0].dims.length === 5) { + throw new Error('Packed QKV is not implemented'); + } + + if (context.inputs[1]?.dims.length === 5) { + throw new Error('Packed KV is not implemented'); + } + + // applyAttention expects BNSH inputs + const kvBNSH = context.inputs[1] && context.inputs[2] && context.inputs[1].dims.length === 4 && + context.inputs[2].dims.length === 4; + + const Q = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0], + context.inputs[3], 0); + + if (kvBNSH) { + return applyAttention( + context, Q, context.inputs[1], context.inputs[2], context.inputs[4], undefined, undefined, undefined, + context.inputs[5], params, attributes); + } + + const K = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.headSize, context.inputs[1], + context.inputs[3], params.hiddenSize); + + const V = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.vHeadSize, context.inputs[2], + context.inputs[3], 2 * params.hiddenSize); + + applyAttention( + context, Q, K, V, context.inputs[4], undefined, context.inputs[6], context.inputs[7], context.inputs[5], params, + attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index 180dab92a453a..18859e253aa02 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -36,8 +36,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { }; const getPadConstant = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[], dataType: string, constantValue: number): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[], + dataType: string, constantValue: number): string => { const inputRank = inputDims.length; let block = ''; @@ -66,8 +66,7 @@ const getPadConstant = }; const getPadReflect = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[]): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { const inputRank = inputDims.length; let block = ''; @@ -97,8 +96,7 @@ const getPadReflect = }; const getPadEdge = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[]): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { const inputRank = inputDims.length; let block = ''; @@ -124,8 +122,7 @@ const getPadEdge = }; const getPadWrap = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[]): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { const inputRank = inputDims.length; let block = ''; @@ -151,18 +148,17 @@ const getPadWrap = }; const getPadSnippet = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], attributes: PadAttributes, dataType: string): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], attributes: PadAttributes, + dataType: string): string => { switch (attributes.mode) { case 0: - return getPadConstant( - output, outputDims, inputDims, inputStrides, attributes.pads, dataType, attributes.value); + return getPadConstant(output, inputDims, inputStrides, attributes.pads, dataType, attributes.value); case 1: - return getPadReflect(output, outputDims, inputDims, inputStrides, attributes.pads); + return getPadReflect(output, inputDims, inputStrides, attributes.pads); case 2: - return getPadEdge(output, outputDims, inputDims, inputStrides, attributes.pads); + return getPadEdge(output, inputDims, inputStrides, attributes.pads); case 3: - return getPadWrap(output, outputDims, inputDims, inputStrides, attributes.pads); + return getPadWrap(output, inputDims, inputStrides, attributes.pads); default: throw new Error('Invalid mode'); } @@ -179,7 +175,7 @@ const generatePadCode = const output = outputVariable('output', inputs[0].dataType, outputDims); const input = inputVariable('x', inputs[0].dataType, inputDims); - const padSnippet = getPadSnippet(output, outputDims, inputDims, inputStrides, attributes, dataType); + const padSnippet = getPadSnippet(output, inputDims, inputStrides, attributes, dataType); const padCode = ` ${shaderHelper.declareVariables(input, output)} ${shaderHelper.mainStart()} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts index 54a7414360c4e..1365d1e9a12a4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -199,7 +199,7 @@ const reduceCommon = let updatedAxes = updatedAttributes.axes; if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) { - updatedAxes = context.inputs[0].dims.map((s, i) => i); + updatedAxes = context.inputs[0].dims.map((_dim, i) => i); } const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 07cfefb8f191b..9869561a36251 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -218,32 +218,30 @@ const initOutputShape = return outputShape; }; -const adjustOutputShape = - (inputShape: readonly number[], outputShape: readonly number[], scales: number[], attributes: ResizeAttributes): - number[] => { - const scaleInPolicy = (() => { - switch (attributes.keepAspectRatioPolicy) { - case 'not_larger': - return attributes.axes.length > 0 ? Math.min(...attributes.axes.map(i => scales[i]), Number.MAX_VALUE) : - Math.min(...scales, Number.MAX_VALUE); - case 'not_smaller': - return attributes.axes.length > 0 ? Math.max(...attributes.axes.map(i => scales[i]), Number.MIN_VALUE) : - Math.max(...scales, Number.MIN_VALUE); - default: - throw new Error(`Keep aspect ratio policy ${attributes.keepAspectRatioPolicy} is not supported`); - } - })(); - scales.fill(1.0, 0, scales.length); - const adjustedOutputShape = inputShape.slice(); - if (attributes.axes.length > 0) { - attributes.axes.forEach((v) => scales[v] = scaleInPolicy); - attributes.axes.forEach((v) => adjustedOutputShape[v] = Math.round(inputShape[v] * scales[v])); - } else { - scales.fill(scaleInPolicy, 0, scales.length); - adjustedOutputShape.forEach((v, i) => adjustedOutputShape[i] = Math.round(v * scales[i])); - } - return adjustedOutputShape; - }; +const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes): number[] => { + const scaleInPolicy = (() => { + switch (attributes.keepAspectRatioPolicy) { + case 'not_larger': + return attributes.axes.length > 0 ? Math.min(...attributes.axes.map(i => scales[i]), Number.MAX_VALUE) : + Math.min(...scales, Number.MAX_VALUE); + case 'not_smaller': + return attributes.axes.length > 0 ? Math.max(...attributes.axes.map(i => scales[i]), Number.MIN_VALUE) : + Math.max(...scales, Number.MIN_VALUE); + default: + throw new Error(`Keep aspect ratio policy ${attributes.keepAspectRatioPolicy} is not supported`); + } + })(); + scales.fill(1.0, 0, scales.length); + const adjustedOutputShape = inputShape.slice(); + if (attributes.axes.length > 0) { + attributes.axes.forEach((v) => scales[v] = scaleInPolicy); + attributes.axes.forEach((v) => adjustedOutputShape[v] = Math.round(inputShape[v] * scales[v])); + } else { + scales.fill(scaleInPolicy, 0, scales.length); + adjustedOutputShape.forEach((v, i) => adjustedOutputShape[i] = Math.round(v * scales[i])); + } + return adjustedOutputShape; +}; const calculateOriginalIndicesFromOutputIndices = (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], @@ -314,8 +312,8 @@ const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): }`; const bilinearInterpolation = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], - scales: readonly number[], useExtrapolation: boolean, extrapolationValue: number): string => { + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[], + useExtrapolation: boolean, extrapolationValue: number): string => { const [batchIdx, heightIdx, widthIdx, channelIdx] = inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]); return ` @@ -445,7 +443,7 @@ const createResizeProgramInfo = if (scalesInput.length === 0) { scales = inputShape.map((value, index) => value === 0 ? 1.0 : outputShape[index] / value); if (attributes.keepAspectRatioPolicy !== 'stretch') { - outputShape = adjustOutputShape(inputShape, outputShape, scales, attributes); + outputShape = adjustOutputShape(inputShape, scales, attributes); } } const output = outputVariable('output', inputTensor.dataType, outputShape); @@ -471,7 +469,7 @@ const createResizeProgramInfo = ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales, roi)}; ${ bilinearInterpolation( - input, output, inputShape, outputShape, scales, useExtrapolation, attributes.extrapolationValue)}; + input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}; `; case 'cubic': return ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index d607351f69b74..7458579bf4340 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; export interface SliceAttributes extends AttributeWithCacheKey { readonly starts: number[]; @@ -77,17 +77,26 @@ const fixStartEndValues = }; const calculateInputIndicesImpl = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]): - string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], + enableInputShapeUniforms: boolean): string => + `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { var inputIndices: ${input.type.indices}; var carry = 0u; for (var i = ${inputShape.length}; i >= 0; i--) { + let input_shape_i = ${ + enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'}; + let steps_i = ${ + enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'}; + let signs_i = ${ + enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'}; + let starts_i = ${ + enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'}; var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; - var inputIndex = outputIndex * steps[i] + starts[i] + carry; - carry = inputIndex / inputShape[i]; - inputIndex = inputIndex % inputShape[i]; - if (signs[i] < 0) { - inputIndex = inputShape[i] - inputIndex - 1u + starts[i]; + var inputIndex = outputIndex * steps_i + starts_i + carry; + carry = inputIndex / input_shape_i; + inputIndex = inputIndex % input_shape_i; + if (signs_i < 0) { + inputIndex = input_shape_i - inputIndex - 1u + starts_i; } ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex; } @@ -110,6 +119,10 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps)); + if (axes.length !== starts.length || axes.length !== ends.length) { + throw new Error('start, ends and axes should have the same number of elements'); + } + if (axes.length !== inputShape.length) { for (let i = 0; i < inputShape.length; ++i) { if (!axes.includes(i)) { @@ -131,40 +144,66 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice array[i] = -step; } }); + // Output rank is expected to be less than or equal to the input rank. + const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length); + const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims; const outputShape = inputShape.slice(0); axes.forEach((axis, _) => { outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); }); + const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape; const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType}; - const output = outputVariable('output', inputs[0].dataType, outputShape); - const input = inputVariable('input', inputs[0].dataType, inputShape); + const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); + const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank); const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = []; + const uniforms: UniformsArrayType = []; + if (enableShapeUniforms) { + uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}` : 'u32'}); + uniforms.push({name: 'signs', type: signs.length > 1 ? `vec${signs.length}` : 'i32'}); + uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}` : 'u32'}); + programUniforms.push({type: 'uint32', data: starts}); + programUniforms.push({type: 'int32', data: signs}); + programUniforms.push({type: 'uint32', data: steps}); + } + uniforms.push({name: 'outputSize', type: 'u32'}); + programUniforms.push({type: 'uint32', data: outputSize}); + if (enableShapeUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(outputShape)); + } const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, output)} - const signs = array(${signs.map(i => `${i}i`).join(',')}); - const starts = array(${starts.map(i => `${i}u`).join(',')}); - const ends = array(${ends.map(i => `${i}u`).join(',')}); - const steps = array(${steps.map(i => `${i}u`).join(',')}); - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - - ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} + ${enableShapeUniforms ? '' : [ + `const signs = array(${signs.map(i => `${i}i`).join(',')});`, + `const starts = array(${starts.map(i => `${i}u`).join(',')});`, + `const steps = array(${steps.map(i => `${i}u`).join(',')});`, + `const inputShape = array(${inputShape.map(i => `${i}u`).join(',')});` + ].join('\n')} + + ${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = ${output.offsetToIndices('global_idx')}; let inputIndices = calculateInputIndices(outputIndices); ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} }`; return { name: 'Slice', - shaderCache: {hint: `${attributes.cacheKey}|${inputs[4]?.dims ?? ''}`}, + shaderCache: { + hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` : + `${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`, + inputDependencies: [enableShapeUniforms ? 'rank' : 'dims'] + }, getShaderSource, getRunData: () => ({ outputs: [outputTensorInfo], dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, + programUniforms }) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 4238449f9246f..119609e06f5a3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -124,7 +124,14 @@ export interface ClipAttributes extends AttributeWithCacheKey { readonly max: number; } -export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => { +const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { + const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; + const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; + return createAttributeWithCacheKey({min, max}); +}; + +export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => { + const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs); const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); context.compute( createElementwiseProgramInfo( @@ -135,16 +142,6 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo attributes.cacheKey), {inputs: [0]}); }; -const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { - const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; - const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; - return createAttributeWithCacheKey({min, max}); -}; - -export const clip = (context: ComputeContext): void => { - const attributes = generateClipAttributesFromInputs(context.inputs); - clipV10(context, attributes); -}; export const ceil = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil')); diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 02691b2b6e626..0b0a545f46481 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -68,7 +68,7 @@ export class ProgramManager { this.backend.querySetCount * 8, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST); this.backend.endComputePass(); - this.backend.getCommandEncoder().resolveQuerySet(this.backend.querySet, 0, 2, this.backend.queryData.buffer, 0); + this.backend.getCommandEncoder().resolveQuerySet(this.backend.querySet!, 0, 2, this.backend.queryData.buffer, 0); this.backend.getCommandEncoder().copyBufferToBuffer( this.backend.queryData.buffer, 0, syncData.buffer, 0, this.backend.querySetCount * 8); this.backend.flush(); @@ -77,7 +77,7 @@ export class ProgramManager { const kernelInfo = this.backend.kernels.get(kernelId)!; const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`; - syncData.buffer.mapAsync(GPUMapMode.READ).then(() => { + void syncData.buffer.mapAsync(GPUMapMode.READ).then(() => { const mappedData = new BigUint64Array(syncData.buffer.getMappedRange()); const startTimeU64 = mappedData[0]; const endTimeU64 = mappedData[1]; diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 1f4595818e5c0..1cb6d9e391e4f 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -3,7 +3,39 @@ /// -import {OrtWasmMessage} from '../proxy-messages'; +// +// * type hack for "HTMLImageElement" +// +// in typescript, the type of "HTMLImageElement" is defined in lib.dom.d.ts, which is conflict with lib.webworker.d.ts. +// when we use webworker, the lib.webworker.d.ts will be used, which does not have HTMLImageElement defined. +// +// we will get the following errors complaining that HTMLImageElement is not defined: +// +// ==================================================================================================================== +// +// ../common/dist/cjs/tensor-factory.d.ts:187:29 - error TS2552: Cannot find name 'HTMLImageElement'. Did you mean +// 'HTMLLIElement'? +// +// 187 fromImage(imageElement: HTMLImageElement, options?: TensorFromImageElementOptions): +// Promise | TypedTensor<'uint8'>>; +// ~~~~~~~~~~~~~~~~ +// +// node_modules/@webgpu/types/dist/index.d.ts:83:7 - error TS2552: Cannot find name 'HTMLImageElement'. Did you mean +// 'HTMLLIElement'? +// +// 83 | HTMLImageElement +// ~~~~~~~~~~~~~~~~ +// +// ==================================================================================================================== +// +// `HTMLImageElement` is only used in type declaration and not in real code. So we define it as `unknown` here to +// bypass the type check. +// +declare global { + type HTMLImageElement = unknown; +} + +import {OrtWasmMessage, SerializableTensorMetadata} from '../proxy-messages'; import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; @@ -11,7 +43,7 @@ self.onmessage = (ev: MessageEvent): void => { switch (ev.data.type) { case 'init-wasm': try { - initializeWebAssembly(ev.data.in) + initializeWebAssembly(ev.data.in!) .then( () => postMessage({type: 'init-wasm'} as OrtWasmMessage), err => postMessage({type: 'init-wasm', err} as OrtWasmMessage)); @@ -21,10 +53,10 @@ self.onmessage = (ev: MessageEvent): void => { break; case 'init-ort': try { - initRuntime(ev.data.in).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({ - type: 'init-ort', - err - } as OrtWasmMessage)); + initRuntime(ev.data.in!).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({ + type: 'init-ort', + err + } as OrtWasmMessage)); } catch (err) { postMessage({type: 'init-ort', err} as OrtWasmMessage); } @@ -58,8 +90,7 @@ self.onmessage = (ev: MessageEvent): void => { break; case 'release': try { - const handler = ev.data.in!; - releaseSession(handler); + releaseSession(ev.data.in!); postMessage({type: 'release'} as OrtWasmMessage); } catch (err) { postMessage({type: 'release', err} as OrtWasmMessage); @@ -68,10 +99,16 @@ self.onmessage = (ev: MessageEvent): void => { case 'run': try { const {sessionId, inputIndices, inputs, outputIndices, options} = ev.data.in!; - run(sessionId, inputIndices, inputs, outputIndices, options) + run(sessionId, inputIndices, inputs, outputIndices, new Array(outputIndices.length).fill(null), options) .then( outputs => { - postMessage({type: 'run', out: outputs} as OrtWasmMessage, extractTransferableBuffers(outputs)); + if (outputs.some(o => o[3] !== 'cpu')) { + postMessage({type: 'run', err: 'Proxy does not support non-cpu tensor location.'}); + } else { + postMessage( + {type: 'run', out: outputs} as OrtWasmMessage, + extractTransferableBuffers(outputs as SerializableTensorMetadata[])); + } }, err => { postMessage({type: 'run', err} as OrtWasmMessage); diff --git a/js/web/lib/wasm/proxy-worker/tsconfig.json b/js/web/lib/wasm/proxy-worker/tsconfig.json new file mode 100644 index 0000000000000..ec1044612a569 --- /dev/null +++ b/js/web/lib/wasm/proxy-worker/tsconfig.json @@ -0,0 +1,8 @@ +{ + "extends": "../../../tsconfig.json", + "compilerOptions": { + "lib": ["WebWorker"] + }, + "include": ["main.ts", "../../build-def.d.ts"], + "exclude": [] +} diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 02ff229cc4954..45ea48a2df209 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -75,6 +75,19 @@ const setExecutionProviders = checkLastError(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}.`); } } + if (webnnOptions?.numThreads) { + let numThreads = webnnOptions.numThreads; + // Just ignore invalid webnnOptions.numThreads. + if (typeof numThreads != 'number' || !Number.isInteger(numThreads) || numThreads < 0) { + numThreads = 0; + } + const keyDataOffset = allocWasmString('numThreads', allocs); + const valueDataOffset = allocWasmString(numThreads.toString(), allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== + 0) { + checkLastError(`Can't set a session config entry: 'numThreads' - ${webnnOptions.numThreads}.`); + } + } if (webnnOptions?.powerPreference) { const keyDataOffset = allocWasmString('powerPreference', allocs); const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs); diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 9567bc172c9ed..890c5a0f34765 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -25,7 +25,7 @@ "@types/minimatch": "^5.1.2", "@types/minimist": "^1.2.2", "@types/platform": "^1.3.4", - "@webgpu/types": "^0.1.30", + "@webgpu/types": "^0.1.38", "base64-js": "^1.5.1", "chai": "^4.3.7", "electron": "^23.1.2", @@ -323,9 +323,9 @@ } }, "node_modules/@webgpu/types": { - "version": "0.1.30", - "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.30.tgz", - "integrity": "sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg==", + "version": "0.1.38", + "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.38.tgz", + "integrity": "sha512-7LrhVKz2PRh+DD7+S+PVaFd5HxaWQvoMqBbsV9fNJO1pjUs1P8bM2vQVNfk+3URTqbuTI7gkXi0rfsN0IadoBA==", "dev": true }, "node_modules/accepts": { @@ -3767,9 +3767,9 @@ } }, "@webgpu/types": { - "version": "0.1.30", - "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.30.tgz", - "integrity": "sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg==", + "version": "0.1.38", + "resolved": "https://registry.npmjs.org/@webgpu/types/-/types-0.1.38.tgz", + "integrity": "sha512-7LrhVKz2PRh+DD7+S+PVaFd5HxaWQvoMqBbsV9fNJO1pjUs1P8bM2vQVNfk+3URTqbuTI7gkXi0rfsN0IadoBA==", "dev": true }, "accepts": { diff --git a/js/web/package.json b/js/web/package.json index 7271fed99d709..9b4531d7766fe 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -24,6 +24,7 @@ "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", + "prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit", "build": "node ./script/build", "test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli", "prepack": "node ./script/build && node ./script/prepack" @@ -42,7 +43,7 @@ "@types/minimatch": "^5.1.2", "@types/minimist": "^1.2.2", "@types/platform": "^1.3.4", - "@webgpu/types": "^0.1.30", + "@webgpu/types": "^0.1.38", "base64-js": "^1.5.1", "chai": "^4.3.7", "electron": "^23.1.2", diff --git a/js/web/script/generate-webgpu-operator-md.ts b/js/web/script/generate-webgpu-operator-md.ts index 7408f17004f5e..eab8175a941bd 100644 --- a/js/web/script/generate-webgpu-operator-md.ts +++ b/js/web/script/generate-webgpu-operator-md.ts @@ -16,6 +16,8 @@ const COMMENTS: Record = { 'Reshape': 'no GPU kernel', 'Shape': 'no GPU kernel; an ORT warning is generated - need to fix', 'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling', + 'Attention': 'need implementing mask and past/present', + 'MultiHeadAttention': 'need implementing mask and past/present', }; /* eslint-disable max-len */ diff --git a/js/web/test/data/ops/attention.jsonc b/js/web/test/data/ops/attention.jsonc new file mode 100644 index 0000000000000..bd4483027cc25 --- /dev/null +++ b/js/web/test/data/ops/attention.jsonc @@ -0,0 +1,557 @@ +[ + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [4, 3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [213, 213], + "dims": [1, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic Batch 2 with 2 heads", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16 + ], + "dims": [2, 2, 8], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 + ], + "dims": [8, 6], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [320, 321, 320, 321, 320, 321, 320, 321], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863], + "dims": [1, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1.328187108039856, -1.297916054725647, -0.8599594831466675], + "dims": [1, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic one head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [2, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 2 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643], + "dims": [2, 6], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989, -0.989, 1.1103, -1.6898], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.8701779842376709, -2.6158859729766846, 0.8710794448852539, -2.5763747692108154, 0.9005484580993652, + -2.182751178741455, 2.1661579608917236, -2.1045265197753906, 1.6716957092285156, -1.797281265258789, + 1.7134947776794434, -1.765358328819275 + ], + "dims": [2, 3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987 + ], + "dims": [2, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [ + 1.1103, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, -1.6898, -0.989, -1.9029953479766846, 0.8710794448852539, + -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, 1.7134947776794434 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.6956915855407715, -2.8863370418548584, 1.3899128437042236, 1.6789076328277588, -1.4083852767944336, + -1.7009180784225464, -3.1053788661956787, 3.5959298610687256, 1.1027096509933472, -0.009643087163567543, + -1.694351315498352, -2.9284396171569824, 1.734721302986145, 2.0606398582458496, -0.2571452260017395, + 3.671973943710327, -5.285338401794434, -6.833454132080078, 1.7506506443023682, -2.262148380279541, + 2.5110034942626953, 1.440049171447754, -0.9423203468322754, 1.7506506443023682, -1.86212158203125, + -0.5036701560020447, -5.732386589050293, -1.5674757957458496, 1.7506510019302368, -2.264472246170044 + ], + "dims": [2, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846 + ], + "dims": [1, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [1, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [3, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + 3.7965505123138428, -2.3799397945404053, -3.9530906677246094, 0.5844926834106445, -2.9756431579589844, + 2.448162794113159, 4.34546422958374, 1.9380426406860352, 0.5870105624198914, -2.7368364334106445, + -0.4769568145275116, 4.255186557769775, -3.9529950618743896, 0.6987408995628357, -2.9756433963775635 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.01101303100586, -5.782258987426758, 6.016238689422607, 0.26747000217437744, -6.992541313171387, + -8.011263847351074, -5.782248020172119, 5.366001129150391, 0.26747000217437744, -6.99449348449707, + -8.011263847351074, -5.782265663146973, 6.016238689422607, 0.26747000217437744, -6.992537021636963, + -6.102723598480225, -7.28973388671875, -4.578637599945068, 7.2203369140625, -6.028444766998291, + -6.102705478668213, -7.2897748947143555, -3.7882626056671143, 5.393260478973389, -5.754333972930908, + -1.3616288900375366, -7.289827823638916, -6.341128349304199, 6.329389572143555, -5.751791954040527, + -2.3945987224578857, -14.532954216003418, 3.969801902770996, 12.744998931884766, -11.1966552734375, + -2.4002532958984375, -14.538958549499512, -6.684961318969727, 12.476543426513672, -9.24352741241455, + -4.787771701812744, -8.640848159790039, 3.969801902770996, -0.6471102833747864, -11.1966552734375 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 1 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + 1.3541864156723022, -7.813620090484619, -6.758509635925293, 7.597365856170654, -13.926229476928711, + -1.322464108467102, -7.297357559204102, -0.05962071940302849, 6.347561836242676, -5.869992256164551, + -1.3616288900375366, -7.28973388671875, 0.0386197566986084, 6.329389572143555, -5.751791954040527, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + 1.021930456161499, -2.373898983001709, 3.8501391410827637, -0.6108309626579285, -9.256340980529785 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/multi-head-attention.jsonc b/js/web/test/data/ops/multi-head-attention.jsonc new file mode 100644 index 0000000000000..05687bd482e24 --- /dev/null +++ b/js/web/test/data/ops/multi-head-attention.jsonc @@ -0,0 +1,194 @@ +[ + { + "name": "MultiHeadAttention Basic, one head", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.973228454589844, 5.973228454589844, 6.973228454589844, 7.973228454589844, 4.999990940093994, + 5.999990940093994, 6.999990940093994, 7.999990940093994 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.571832656860352, 5.571832656860352, 6.971858501434326, 7.971858501434326, 4.998325824737549, + 5.998325824737549, 6.999900817871094, 7.999900817871094 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic with bias", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4], + "dims": [12], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 5.943336009979248, 7.94333553314209, 9.999799728393555, 11.999798774719238, 5.9997992515563965, + 7.9997992515563965, 10, 11.999999046325684 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 8.99963665008545, 9.99963665008545, 10.99963665008545, 11.999635696411133, 13, 14, 15, 16, 9, 10, 11, 12, + 13, 14, 15, 16 + ], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[1]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 1, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 1, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/slice.jsonc b/js/web/test/data/ops/slice.jsonc index 9c90817a80c36..beef154a29932 100644 --- a/js/web/test/data/ops/slice.jsonc +++ b/js/web/test/data/ops/slice.jsonc @@ -21,6 +21,29 @@ } ] }, + { + "name": "Slice float32 with input[0] dim > 4", + "operator": "Slice", + "attributes": [], + "cases": [ + { + "name": "T[1, 1, 1, 1, 5] T[1] T[1] T[1] (float32)", + "inputs": [ + { + "data": [ + 0.3964604139328003, -0.8916832804679871, -1.6578896045684814, 1.960708737373352, 1.181204915046692 + ], + "dims": [1, 1, 1, 1, 5], + "type": "float32" + }, + { "data": [3], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" } + ], + "outputs": [{ "data": [1.960708737373352], "dims": [1, 1, 1, 1, 1], "type": "float32" }] + } + ] + }, { "name": "Slice int32", "operator": "Slice", diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index c80f0b04a9abc..37aa9394c7f96 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1336,6 +1336,7 @@ "add_int32.jsonc", //"and.jsonc", "asin.jsonc", + "attention.jsonc", "bias-add.jsonc", "bias-split-gelu.jsonc", "ceil.jsonc", @@ -1362,6 +1363,7 @@ "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", + "multi-head-attention.jsonc", //"neg.jsonc", "neg-int32.jsonc", "not.jsonc", diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 0fddddf58181c..8c186b9b36451 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -14,7 +14,7 @@ import {conv2d} from './test-conv-utils'; function createRandomArray(size: number): Float32Array { const randomTable = [0, 3, 6, 9, 2, 5, 8, 1, 4, 7]; return new Float32Array( - Array.from({length: size}, (v, k) => randomTable[k % 10] * 0.1 + randomTable[Math.trunc(k / 10) % 10] * 0.01)); + Array.from({length: size}, (_v, k) => randomTable[k % 10] * 0.1 + randomTable[Math.trunc(k / 10) % 10] * 0.01)); } interface TestData { inputShape: number[]; diff --git a/js/web/test/unittests/backends/webgl/test-pack-unpack.ts b/js/web/test/unittests/backends/webgl/test-pack-unpack.ts index 0b70144733227..61c21d4b689fb 100644 --- a/js/web/test/unittests/backends/webgl/test-pack-unpack.ts +++ b/js/web/test/unittests/backends/webgl/test-pack-unpack.ts @@ -291,7 +291,7 @@ describe('#UnitTest# - unpack - Tensor unpack', () => { webglInferenceHandler.session.textureManager.glContext.checkError(); const webglTexture = createTextureFromArray( webglInferenceHandler.session.textureManager.glContext, testData.rawData ? testData.rawData : inputData, - gl.RGBA, inputTextureShape[0], inputTextureShape[1]); + inputTextureShape[0], inputTextureShape[1]); webglInferenceHandler.session.textureManager.glContext.checkError(); const packedShape = inputTextureShape; const textureData = { diff --git a/js/web/test/unittests/backends/webgl/test-utils.ts b/js/web/test/unittests/backends/webgl/test-utils.ts index acb3f0002ce2f..092d63cd2ade4 100644 --- a/js/web/test/unittests/backends/webgl/test-utils.ts +++ b/js/web/test/unittests/backends/webgl/test-utils.ts @@ -4,7 +4,7 @@ import {WebGLContext} from '../../../../lib/onnxjs/backends/webgl/webgl-context'; export function createAscendingArray(size: number): Float32Array { - return new Float32Array(Array.from({length: size}, (v, i) => (i + 1))); + return new Float32Array(Array.from({length: size}, (_v, i) => (i + 1))); } // Returns an array by injecting 3 zeros after every element in the input array to be used for creating unpacked @@ -19,7 +19,7 @@ export function generateArrayForUnpackedTexture(input: Float32Array): Float32Arr // create a webgl texture and fill it with the array content export function createTextureFromArray( - glContext: WebGLContext, dataArray: Float32Array, type: GLenum, width: number, height: number): WebGLTexture { + glContext: WebGLContext, dataArray: Float32Array, width: number, height: number): WebGLTexture { const gl = glContext.gl; // create the texture diff --git a/js/web/tsconfig.json b/js/web/tsconfig.json index eb4674fab3ca9..d60d746e9328d 100644 --- a/js/web/tsconfig.json +++ b/js/web/tsconfig.json @@ -1,7 +1,6 @@ { "extends": "../tsconfig.json", "compilerOptions": { - "module": "Node16", "downlevelIteration": true, "declaration": true, "typeRoots": ["./node_modules/@webgpu/types", "./node_modules/@types", "../node_modules/@types"] diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 0ed7d887fc5e5..57219c50f39aa 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -61,7 +61,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import OrtValue # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import SparseTensor # noqa: F401 -from onnxruntime.capi.training import * # noqa: F403 # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end try: # noqa: SIM105 diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index b693b58c7c40a..a7f83469a768d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -96,9 +96,9 @@ struct GroupQueryAttentionParameters { int kv_num_heads; int num_splits; // number of splits for splitkv bool is_unidirectional; // causal + int local_window_size; bool kv_share_buffer; - bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor - bool left_padding; // copies last token to last index if true + bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 4a266af789250..47f462d75fcc4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -63,6 +63,16 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int head_size = parameters.head_size; const int position_ids_format = parameters.position_ids_format; const int half_head_size = head_size / 2; + // Default input tensor shape is [batch, seq_len, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (parameters.transposed) { + // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -76,11 +86,10 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int s = static_cast((ptr / num_heads) % sequence_length); const int n = static_cast(ptr % num_heads); - const int block_offset = b * sequence_length * num_heads + s * num_heads + n; - const int data_offset = block_offset * head_size; + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - const T* input_data = input_src + data_offset; - T* output_data = output_dest + data_offset; + const T* input_data = input_src + block_offset; + T* output_data = output_dest + block_offset; // Cache is (M, H/2) const int position_id = (position_ids_format == 0) diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index cf8080800e072..7b2e8289f7b06 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -18,6 +18,7 @@ struct RotaryParameters { int num_heads; // num_heads = hidden_size / head_size int max_sequence_length; // Sequence length used by cos/sin cache int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; template @@ -33,8 +34,8 @@ Status CheckInputs(const T* input, // Check input const auto& input_dims = input->Shape().GetDims(); - if (input_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ", + if (input_dims.size() != 3 && input_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ", input_dims.size()); } // Check position_ids @@ -63,6 +64,14 @@ Status CheckInputs(const T* input, int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); int hidden_size = static_cast(input_dims[2]); + + bool transposed = false; + if (input_dims.size() == 4) { + // input is [batch, num_heads, seq, head_size] + sequence_length = static_cast(input_dims[2]); + hidden_size = static_cast(input_dims[1]) * static_cast(input_dims[3]); + transposed = true; + } int max_sequence_length = static_cast(cos_cache_dims[0]); int head_size = static_cast(cos_cache_dims[1]) * 2; int num_heads = hidden_size / head_size; @@ -111,6 +120,7 @@ Status CheckInputs(const T* input, output_parameters->num_heads = num_heads; output_parameters->max_sequence_length = max_sequence_length; output_parameters->position_ids_format = position_ids_format; + output_parameters->transposed = transposed; } return Status::OK(); @@ -118,4 +128,4 @@ Status CheckInputs(const T* input, } // namespace rotary_embedding_helper } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index c72d811170a27..320a05bb97dac 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -1,35 +1,38 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas.h" +#include "core/mlas/inc/mlas_qnbit.h" +#include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" -#include "core/mlas/inc/mlas_q4.h" namespace onnxruntime { namespace contrib { class MatMulNBits final : public OpKernel { public: - MatMulNBits(const OpKernelInfo& info) : OpKernel(info) { - ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + MatMulNBits(const OpKernelInfo& info) + : OpKernel(info), + K_{narrow(info.GetAttr("K"))}, + N_{narrow(info.GetAttr("N"))}, + block_size_{narrow(info.GetAttr("block_size"))}, + nbits_{narrow(info.GetAttr("bits"))} { ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op," - " additional bits support is planned."); + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); } Status Compute(OpKernelContext* context) const override; private: - int64_t K_; - int64_t N_; - int64_t block_size_; - int64_t nbits_; - bool column_wise_quant_{true}; + const size_t K_; + const size_t N_; + const size_t block_size_; + const size_t nbits_; + const bool column_wise_quant_{true}; }; Status MatMulNBits::Compute(OpKernelContext* ctx) const { @@ -45,11 +48,60 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + TensorShape b_shape({static_cast(N_), static_cast(K_)}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) + return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t batch_count = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + + if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) { + // number of bytes or elements between adjacent matrices + size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes; + MlasBlockwiseQuantizedBufferSizes(static_cast(nbits_), static_cast(block_size_), /* columnwise */ true, + static_cast(K), static_cast(N), + b_data_matrix_stride_in_bytes, b_scale_matrix_stride, + &b_zero_point_matrix_stride_in_bytes); + + const size_t b_matrix_size = K * N; + + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size; + + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes; + data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride; + data[i].QuantBZeroPoint = zero_points_data != nullptr + ? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes + : nullptr; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + } + + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool); + + return Status::OK(); + } + + const size_t ldb = helper.Ldb(true); + AllocatorPtr allocator; - auto status = ctx->GetTempSpaceAllocator(&allocator); - ORT_RETURN_IF_ERROR(status); + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - // dequantize b, only 4b quantization is supported for now MlasDequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output @@ -67,29 +119,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); #endif - TensorShape b_shape({N_, K_}); - - MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); - - Tensor* y = ctx->Output(0, helper.OutputShape()); - - // Bail out early if the output is going to be empty - if (y->Shape().Size() == 0) - return Status::OK(); - - auto* y_data = y->MutableData(); - - const size_t max_len = helper.OutputOffsets().size(); - const size_t M = static_cast(helper.M()); - const size_t N = static_cast(helper.N()); - const size_t K = static_cast(helper.K()); - const size_t lda = helper.Lda(false); - const size_t ldb = helper.Ldb(true); - - // TODO: implement with native kernel - std::vector data(max_len); - for (size_t i = 0; i < max_len; i++) { + std::vector data(batch_count); + for (size_t i = 0; i < batch_count; i++) { data[i].BIsPacked = false; data[i].A = a_data + helper.LeftOffsets()[i]; data[i].lda = lda; @@ -101,7 +132,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { data[i].beta = 0.0f; } MlasGemmBatch(CblasNoTrans, CblasTrans, - M, N, K, data.data(), max_len, thread_pool); + M, N, K, data.data(), batch_count, thread_pool); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 89e2351428d40..cbe536c6ce45a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -69,6 +69,7 @@ struct Flash_fwd_params : public Qkv_params { int seqlen_q_rounded = 0; int seqlen_k_rounded = 0; int d_rounded = 0; + int rotary_dim = 0; // The scaling factors for the kernel. float scale_softmax = 0.0; @@ -92,12 +93,26 @@ struct Flash_fwd_params : public Qkv_params { index_t knew_head_stride = 0; index_t vnew_head_stride = 0; + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr = nullptr; + void* __restrict__ rotary_sin_ptr = nullptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx = nullptr; + + // Local window size + int window_size_left = -1; + int window_size_right = -1; + bool is_bf16 = false; bool is_causal = false; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. bool is_seqlens_k_cumulative = true; + + bool is_rotary_interleaved = false; + int num_splits = 0; // For split-KV version const cudaDeviceProp* dprops = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 89a27c4d2b0d3..76190aad68fdb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -35,7 +35,9 @@ void set_params_fprop(Flash_fwd_params& params, void* softmax_lse_d, float softmax_scale, bool is_causal, - bool kv_bsnh = true) { + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { // Set the pointers and strides. params.q_ptr = q; params.k_ptr = k; @@ -102,7 +104,21 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates + // local and causal, meaning when we have local window size params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.is_seqlens_k_cumulative = true; } @@ -227,7 +243,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - bool kv_bsnh) { + bool kv_bsnh, + int local_window_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -247,7 +264,9 @@ Status mha_fwd(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, - kv_bsnh); + kv_bsnh, + local_window_size, + is_causal ? 0 : -1); params.dprops = &dprops; params.knew_ptr = nullptr; params.vnew_ptr = nullptr; @@ -306,7 +325,10 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, nullptr, softmax_lse, softmax_scale, - is_causal); + is_causal, + true, + -1, + is_causal ? 0 : -1); params.dprops = &dprops; params.num_splits = 0; params.softmax_lseaccum_ptr = nullptr; @@ -347,11 +369,11 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, bool past_bsnh, // otherwise bnsh int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads - void* out_accum // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded -) { - if (seqlen_q == 1) { - is_causal = false; - } // causal=true is the same as causal=false in this case + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size) { + // if (seqlen_q == 1) { + // is_causal = false; + // } // causal=true is the same as causal=false in this case auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); @@ -372,7 +394,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, - past_bsnh); + past_bsnh, + local_window_size, + is_causal ? 0 : -1); params.dprops = &dprops; if (k != nullptr && v != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 58f4304251872..efc1f565c4fa0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -54,7 +54,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - bool kv_bsnh = true); + bool kv_bsnh = true, + int local_window_size = -1); Status mha_varlen_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, @@ -96,8 +97,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, bool past_bsnh, // otherwise bnsh int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads - void* out_accum = nullptr // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded -); + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size = -1); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index eb1c794d6df54..028233f66850f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -29,47 +29,6 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTE_HOST_DEVICE auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - make_layout(cute::size<2>(TileShape_MNK{}))); - - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - // TODO: Shouldn't this be size<1>? - make_layout(cute::size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, Tensor2& acc_o, float softmax_scale_log2) { @@ -123,7 +82,7 @@ inline __device__ void write_softmax_to_gmem( //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -144,12 +103,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); // We exit early and write 0 to gO and gLSE. // Otherwise we might read OOB elements from gK and gV. - if (n_block_max <= 0) { + if (n_block_max <= n_block_min) { const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), @@ -197,7 +158,6 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), cute::Shape, cute::Int>{}, make_stride(params.q_row_stride, _1{})); @@ -332,9 +292,9 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -364,22 +324,22 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -390,8 +350,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // Convert scores from fp32 to fp16/bf16 cute::Tensor rP = flash::convert_type(scores); @@ -408,14 +368,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { + if (n_masking_steps > 1 && n_block <= n_block_min) { --n_block; break; } } // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); @@ -431,7 +391,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -441,8 +401,15 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi } // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); cute::Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -543,7 +510,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -572,11 +539,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons if (m_block * kBlockM >= binfo.actual_seqlen_q) return; const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; - const int n_block_min = n_split_idx * n_blocks_per_split; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); - if (Is_causal) { + if (Is_causal || Is_local) { n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 // We exit early and write 0 to gOaccum and -inf to gLSEaccum. @@ -626,10 +595,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -641,16 +609,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, make_stride(params.v_row_stride, _1{})); - // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, - // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. - // This maps to accessing the first 64 rows of knew_ptr. - Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), - Shape, Int>{}, - make_stride(params.knew_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } - Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), - Shape, Int>{}, - make_stride(params.vnew_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQ{}); @@ -664,11 +622,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); typename Kernel_traits::TiledMma tiled_mma; @@ -732,17 +688,129 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons } // Prologue + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + } + } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + if (n_block_max > n_block_copy_min) { + tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; + tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; + } + } + // Read Q from gmem to smem, optionally apply rotary embedding. Tensor tQrQ = make_fragment_like(tQgQ); - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } + } int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // flash::cp_async_wait<0>(); @@ -760,9 +828,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -770,32 +838,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (cute::thread0()) { print(tKgK); } - // if (cute::thread0()) { print(tKsK); } - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - // __syncthreads(); - // if (cute::thread0()) { print(tKgK); } - // __syncthreads(); - } - // Advance gV if (masking_step > 0) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - if (Append_KV) { - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); } cute::cp_async_fence(); @@ -810,15 +860,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); @@ -826,26 +876,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } // __syncthreads(); - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); } - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } - if (n_block > n_block_min) { // Advance gK - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); } tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - if (Append_KV) { - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - } - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, - binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -853,8 +887,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } // Convert scores from fp32 to fp16/bf16 @@ -879,20 +913,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } // Advance gV tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - if (Append_KV) { - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( @@ -901,22 +924,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - if (Append_KV) { - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, - binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -924,7 +935,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -1031,7 +1049,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1047,12 +1065,12 @@ inline __device__ void compute_attn(const Params& params) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1061,24 +1079,23 @@ inline __device__ void compute_attn_splitkv(const Params& params) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void combine_attn_seqk_parallel(const Params& params) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; constexpr int kMaxSplits = 1 << Log_max_splits; - constexpr int kBlockM = 16; constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer"); - static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32"); - static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); // Shared memory. // kBlockM + 1 instead of kBlockM to reduce bank conflicts. @@ -1094,10 +1111,10 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { make_stride(params.b * params.h * params.seqlen_q, _1{})); Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); - constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads; + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // Read the LSE values from gmem and store them in shared memory, then tranpose them. - constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM; + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; #pragma unroll for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; @@ -1165,7 +1182,12 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), Shape, Int>{}, Stride, _1>{}); - typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); Tensor tOrO = make_tensor(shape(tOgOaccum)); @@ -1183,8 +1205,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } } -// Load Oaccum in then scale and accumulate to O -#pragma unroll 2 + // Load Oaccum in then scale and accumulate to O for (int split = 0; split < params.num_splits; ++split) { flash::copy( gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index 82dfa59b8f8e7..87d189a803f8a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -10,29 +10,30 @@ namespace onnxruntime { namespace flash { -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn(params); + flash::compute_attn(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn_splitkv(params); + flash::compute_attn_splitkv(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static_assert(Log_max_splits >= 1); - flash::combine_attn_seqk_parallel(params); + flash::combine_attn_seqk_parallel(params); #else (void)params; #endif @@ -52,20 +53,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { const bool is_even_K = params.d == Kernel_traits::kHeadDim; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - // ORT_ENFORCE(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); }); }); } @@ -82,40 +88,46 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(params.num_splits > 1, Split, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); - auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV > ; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); }); }); }); }); }); if (params.num_splits > 1) { - dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16); + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } }); } @@ -130,7 +142,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 32; + constexpr static int Headdim = 32; BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_causal>(params, stream); }); @@ -138,7 +150,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 64; + constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using block size (64 x 256) is 27% slower for seqlen=2k @@ -174,8 +186,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 128; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 128; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. @@ -201,8 +213,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 160; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 160; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, H100, 128 x 32 is the fastest. // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -241,12 +253,11 @@ void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 224; - constexpr size_t threshold = 2 * Headdim * (128 + 2 * 64); - size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + constexpr static int Headdim = 224; + int max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= threshold) { // 112 KB + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); @@ -262,16 +273,14 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 256; - constexpr size_t min_threshold = 2 * Headdim * (128 + 2 * 64); - constexpr size_t max_threshold = 4 * Headdim * (64 + 2 * 64); + constexpr static int Headdim = 256; size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= min_threshold && max_smem_per_sm < max_threshold) { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h index 134f159e258c4..1c0ed7f2fc2e8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -161,7 +161,14 @@ struct Flash_fwd_kernel_traits : public Base { cute::Stride<_16, _1>>>; using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, - cute::Layout>{})); // Val layout, 4 vals per store + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 842edf3a98a86..8017f83bbb01d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor& tensor, const int max_ } } -template -inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride) { +template +inline __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; @@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } @@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i } } +template +inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, + max_seqlen_q, warp_row_stride, -1, 0); +} + template inline __device__ void apply_mask_causal_w_idx( Tensor& tensor, Tensor const& idx_rowcol, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h index 02042e183f808..271112c5e890a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -307,7 +307,7 @@ template inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, int max_MN = 0) { + Tensor const& predicate_K, const int max_MN = 0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -334,65 +334,161 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor const //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor const& S0, - Tensor const& S1, +inline __device__ void copy_w_min_idx(Tensor const& S, Tensor& D, Tensor const& identity_MN, Tensor const& predicate_K, - const int max_MN = 0, const int row_idx_switch = 0) { - CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{}); + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K - // There's no case where !Clear_OOB_K && Clear_OOB_MN - static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); -// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); } -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); } + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } #pragma unroll - for (int m = 0; m < size<1>(S0); ++m) { - auto& S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1; - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } #pragma unroll - for (int k = 0; k < size<2>(S0); ++k) { + for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { - cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_interleaved(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { cute::clear(D(_, m, k)); } } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_w_min_idx(Tensor const& S, - Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, - const int max_MN = 0, const int min_MN = 0) { +inline __device__ void copy_rotary_contiguous(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); #pragma unroll for (int m = 0; m < size<1>(S); ++m) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } #pragma unroll for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || predicate_K(k)) { - cute::copy(S(_, m, k), D(_, m, k)); + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index f21dff08e0350..93892169f6c79 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -44,9 +44,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); num_heads_ = static_cast(num_heads); kv_num_heads_ = static_cast(kv_num_heads); - is_unidirectional_ = true; - // left_padding_ = info.GetAttrOrDefault("left_padding_last_token", 0) == 1; is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -92,8 +91,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { is_past_bsnh_, scale_, device_prop.maxThreadsPerBlock)); - parameters.is_unidirectional = is_unidirectional_; - // parameters.left_padding = left_padding_; + parameters.local_window_size = local_window_size_; int sequence_length = parameters.sequence_length; TensorShapeVector output_shape(3); @@ -139,6 +137,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { bool use_memory_efficient_attention = !use_flash_attention && !disable_memory_efficient_attention_ && + local_window_size_ == -1 && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -222,6 +221,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.k = reinterpret_cast(k_buffer.get()); data.v = reinterpret_cast(v_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index aade0436dc141..54a8127e29e7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -22,8 +22,7 @@ class GroupQueryAttention final : public CudaKernel { protected: int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention - // bool left_padding_; // shifts last token to end of buffer - bool is_unidirectional_; // causal + int local_window_size_; bool is_past_bsnh_; float scale_; bool disable_flash_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 2d158155eeba9..b22ccb68c1e7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -468,55 +468,6 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i return CUDA_CALL(cudaGetLastError()); } -// // Kernel to append new kv to kv buffer in place -// template -// __global__ void LeftPadLast(const int max_seqlen, -// T* kv_buff, -// const int* seqlens_k) { // refers to kv buff; otherwise bnsh -// const int h = threadIdx.x; -// const int n = blockIdx.x; -// const int b = blockIdx.y; - -// const int num_heads = gridDim.x; -// const int H = blockDim.x; - -// const int present_batch_stride = max_seqlen * num_heads * H; -// const int present_row_stride = num_heads * H; -// const int present_head_stride = H; - -// // kv_buff: BTNH or BNTH with buffered memory for new -// // new_kv: BLNH - -// const int s = seqlens_k[b]; - -// const int in_offset = b * present_batch_stride + s * present_row_stride + n * present_head_stride + h; -// const int out_offset = b * present_batch_stride + (max_seqlen - 1) * present_row_stride + n * present_head_stride + h; -// kv_buff[out_offset] = kv_buff[in_offset]; -// } - -// // Concat new to kv buffer in place -// template -// Status LaunchLeftPadLast(contrib::GroupQueryAttentionParameters& parameters, -// GroupQueryAttentionData& data, -// cudaStream_t stream, -// const int max_threads_per_block) { -// const int batch_size = parameters.batch_size; -// const int sequence_length = parameters.sequence_length; -// const int num_heads = parameters.num_heads; -// const int head_size = parameters.head_size; - -// // Indicates past sequence_length of each sequence -// const int* seqlens_k = reinterpret_cast(data.seqlens_k); - -// const int H = head_size / 4; -// const dim3 grid(num_heads, batch_size, 1); -// const dim3 block(H, 1, 1); -// LeftPadLast<<>>(sequence_length, -// reinterpret_cast(data.output), -// seqlens_k); -// return CUDA_CALL(cudaGetLastError()); -// } - ////////// Launch Kernels #if USE_FLASH_ATTENTION @@ -541,7 +492,7 @@ Status FlashAttention( void* key = reinterpret_cast(const_cast(data.key)); void* value = reinterpret_cast(const_cast(data.value)); - bool is_causal = parameters.is_unidirectional; + bool is_causal = true; // Note: seqlens_k is past sequence length for flash if (parameters.is_prompt) { @@ -579,7 +530,7 @@ Status FlashAttention( seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, present_sequence_length, kv_sequence_length, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); + reinterpret_cast(data.out_accum), parameters.local_window_size)); } else { // Not share buffer case // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient @@ -611,13 +562,9 @@ Status FlashAttention( seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, present_sequence_length, 0, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); + reinterpret_cast(data.out_accum), parameters.local_window_size)); } - // if (parameters.left_padding && parameters.is_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -704,9 +651,11 @@ Status EfficientAttention( p.max_sequence_length = present_sequence_length; p.qk_head_size = head_size; p.v_head_size = head_size; - p.causal = parameters.is_unidirectional; + p.causal = true; p.scale = scale; p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; p.query = query; p.key = key; p.value = value; @@ -721,10 +670,6 @@ Status EfficientAttention( p.has_custom_right_padding = true; run_memory_efficient_attention(p); - // if (parameters.left_padding && parameters.is_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index b4b5dac1fbe19..2d12e975d88d7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -74,7 +74,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { parameters.max_sequence_length, parameters.position_ids_format, interleaved, - device_prop.maxThreadsPerBlock); + device_prop.maxThreadsPerBlock, + parameters.transposed); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index c54e72dcfce13..e1b83bd8caf54 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -27,7 +27,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int num_heads, const int head_size, const int position_ids_format, - const bool interleaved) { + const bool interleaved, + const int batch_stride, + const int seq_stride, + const int head_stride) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently @@ -37,11 +40,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int i = threadIdx.x; - const int block_offset = b * sequence_length * num_heads + s * num_heads + n; - const int data_offset = block_offset * head_size; + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - const T* input_data = input + data_offset; - T* output_data = output + data_offset; + const T* input_data = input + block_offset; + T* output_data = output + block_offset; // Cache is (M, H/2) const int half_head_size = head_size / 2; @@ -83,7 +85,8 @@ Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block) { + const int max_threads_per_block, + const bool transposed) { constexpr int smem_size = 0; const dim3 grid(num_heads, sequence_length, batch_size); @@ -94,10 +97,22 @@ Status LaunchRotaryEmbeddingKernel( // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. + // Default input tensor shape is [batch, seq, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (transposed) { + // When transposed, input tensor shape is [batch, num_heads, seq, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } + assert(head_size <= max_threads_per_block); RotaryEmbeddingBSNH<<>>( output, input, cos_cache, sin_cache, position_ids, - sequence_length, num_heads, head_size, position_ids_format, interleaved + sequence_length, num_heads, head_size, position_ids_format, interleaved, + batch_stride, seq_stride, head_stride ); return CUDA_CALL(cudaGetLastError()); @@ -117,7 +132,8 @@ template Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); template Status LaunchRotaryEmbeddingKernel( cudaStream_t stream, @@ -133,7 +149,8 @@ template Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index 29ff48a8ad0fb..ee1ccc43dcbff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -24,7 +24,8 @@ Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 6162895d00d79..108eea1a73fe9 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -70,6 +70,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); @@ -119,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); @@ -260,6 +263,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -309,6 +314,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc index 251850f621361..6cdccdb1becb1 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc @@ -14,17 +14,23 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL() \ - ONNX_OPERATOR_KERNEL_EX( \ - GemmFloat8, \ - kMSDomain, \ - 1, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("TA", BuildKernelDefConstraints()) \ - .TypeConstraint("TB", BuildKernelDefConstraints()) \ - .TypeConstraint("TR", BuildKernelDefConstraints()) \ - .TypeConstraint("TS", BuildKernelDefConstraints()), \ +#if !defined(DISABLE_FLOAT8_TYPES) +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#else +#define GEMM_FLOAT8_CONSTRAINTS BuildKernelDefConstraints() +#endif + +#define REGISTER_KERNEL() \ + ONNX_OPERATOR_KERNEL_EX( \ + GemmFloat8, \ + kMSDomain, \ + 1, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("TA", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TB", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TR", GEMM_FLOAT8_CONSTRAINTS) \ + .TypeConstraint("TS", BuildKernelDefConstraints()), \ GemmFloat8); REGISTER_KERNEL() @@ -38,7 +44,7 @@ GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) { alpha_ = info.GetAttrOrDefault("alpha", 1); beta_ = info.GetAttrOrDefault("beta", 0); -#if (CUDA_VERSION <= 12000) +#if (CUDA_VERSION < 12000) ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0."); #endif diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index df25342342cd5..56b541f5256bf 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -28,7 +28,7 @@ int32_t TypeSize(int32_t element_type) { case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: return 2; -#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) +#if !defined(DISABLE_FLOAT8_TYPES) case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: return 1; @@ -97,12 +97,16 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { } auto first_type = input_A->GetElementType(); +#if !defined(DISABLE_FLOAT8_TYPES) bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; if (!is_float8) +#endif return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, input_C, scale_A, scale_B, scale_Y); +#if !defined(DISABLE_FLOAT8_TYPES) return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, input_C, scale_A, scale_B, scale_Y); +#endif } Status GemmFloat8::ComputeRowMajor( @@ -197,10 +201,15 @@ Status GemmFloat8::ComputeGemm( switch (d_cuda_type) { case CUDA_R_16F: switch (a_cuda_type) { +#if !defined(DISABLE_FLOAT8_TYPES) +#if CUDA_VERSION < 11080 +#error CUDA_R_8F_E4M3 (float 8 types) is defined with CUDA>=11.8. Set flag DISABLE_FLOAT8_TYPES. +#endif case CUDA_R_8F_E4M3: case CUDA_R_8F_E5M2: compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; break; +#endif default: compute_type = CUBLAS_COMPUTE_32F_FAST_16F; break; @@ -267,7 +276,7 @@ Status GemmFloat8::ComputeGemm( sizeof(p_scale_b))); // float 8 -#if CUDA_VERSION >= 11080 +#if !defined(DISABLE_FLOAT8_TYPES) if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) { // For FP8 output, cuBLAS requires C_type to be same as bias_type @@ -280,15 +289,14 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR( cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); } - } else { - CUBLAS_RETURN_IF_ERROR( - cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); - } #else - // An output is still needed but it is not initialized. CUBLAS_RETURN_IF_ERROR( cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); #endif + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } if (row_major_compute) { cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; @@ -345,7 +353,7 @@ Status GemmFloat8::ComputeGemm( ". Check NVIDIA documentation to see what combination is valid: ", "https://docs.nvidia.com/cuda/cublas/" "index.html?highlight=cublasLtMatmulAlgoGetHeuristic#" - "cublasltmatmulalgogetheuristic."); + "cublasltmatmulalgogetheuristic. CUDA>=11.8 is required to use float 8 types."); void* workspace = nullptr; if (workspaceSize > 0) { @@ -381,7 +389,8 @@ Status GemmFloat8::ComputeGemm( ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb, ", ldd=", ldd, ", workspaceSize=", workspaceSize, - ", rowMajorCompute=", (row_major_compute ? 1 : 0), "."); + ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". CUDA>=11.8 is required to use float 8 types."); if (workspaceSize > 0) { CUDA_RETURN_IF_ERROR(cudaFree(workspace)); diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h new file mode 100644 index 0000000000000..86136ea244e23 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include "cutlass/device_kernel.h" + +using namespace onnxruntime; + +namespace ort_fastertransformer { + +template +inline int compute_occupancy_for_kernel() { + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size > (48 << 10)) { + cudaError_t status = + cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (status == cudaError::cudaErrorInvalidValue) { + // Clear the error bit since we can ignore this. + // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an + // occupancy of 0. This will cause the heuristic to ignore this configuration. + status = cudaGetLastError(); + return 0; + } + CUDA_CALL_THROW(status); + } + + int max_active_blocks = -1; + CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::Kernel, + GemmKernel::kThreadCount, smem_size)); + + return max_active_blocks; +} + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc new file mode 100644 index 0000000000000..5d4c6793ec995 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cutlass_heuristic.h" + +#include +#include +#include + +namespace ort_fastertransformer { + +struct TileShape { + int m; + int n; +}; + +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { + switch (tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + default: + ORT_THROW("[FT Error][get_grid_shape_for_config] Invalid config"); + } +} + +bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape, + const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only) { + // All tile sizes have a k_tile of 64. + static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; + } + + if ((k % split_k_factor) != 0) { + return false; + } + + const int k_elements_per_split = static_cast(k / split_k_factor); + if ((k_elements_per_split % k_tile) != 0) { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) { + return false; + } + + return true; +} + +std::vector get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) { + std::vector simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + + std::vector square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64}; + + std::vector quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + + const std::vector allowed_configs = is_weight_only ? quant_B_configs : square_configs; + return simt_configs_only ? simt_configs : allowed_configs; +} + +std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) { + std::vector tiles = get_candidate_tiles(is_weight_only, simt_configs_only); + + std::vector candidate_configs; + const int min_stages = 2; + const int max_stages = sm >= 80 ? 4 : 2; + + for (const auto& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages}; + candidate_configs.push_back(config); + } + } + + return candidate_configs; +} + +CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, + const std::vector& occupancies, const int64_t m, + const int64_t n, const int64_t k, const int64_t, + const int split_k_limit, const size_t workspace_bytes, + const int multi_processor_count, const int is_weight_only) { + if (occupancies.size() != candidate_configs.size()) { + ORT_THROW( + "[FT Error][estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); + int occupancy = occupancies[ii]; + + if (occupancy == 0) { + continue; + } + + // Keep small tile sizes when possible. + if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile && + current_m_tile < tile_shape.m) { + continue; + } + + const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + + for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { + if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) { + const int ctas_per_wave = occupancy * multi_processor_count; + const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave); + const float current_score = float(num_waves_total) - num_waves_fractional; + + const float score_slack = 0.1f; + if (current_score < config_score || + ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = + CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + current_m_tile = tile_shape.m; + } else if (current_score == config_score && + (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor || + current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = + CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) { + ORT_THROW("[FT Error] Heurisitc failed to find a valid config."); + } + + return best_config; +} + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h new file mode 100644 index 0000000000000..e70efe0503b55 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ft_gemm_configs.h" + +#include +#include +#include + +#include "core/common/common.h" + +using namespace onnxruntime; + +namespace ort_fastertransformer { + +std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only); + +CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, + const std::vector& occupancies, const int64_t m, + const int64_t n, const int64_t k, const int64_t num_experts, + const int split_k_limit, const size_t workspace_bytes, + const int multi_processor_count, const int is_weight_only); + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h new file mode 100644 index 0000000000000..78d206bf1d9bc --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. + * + */ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" + +namespace cutlass { +namespace epilogue { +namespace thread { + +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) { +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +template <> +struct GELU_taylor { + static const bool kIsHeavy = true; + CUTLASS_DEVICE + float operator()(float const& z) const { + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); + + return float( + cutlass::constants::half() * z * + (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); } +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +namespace ort_fastertransformer { + +struct EpilogueOpBiasSilu {}; + +struct EpilogueOpBiasReLU {}; + +struct EpilogueOpBiasFtGelu {}; + +struct EpilogueOpBias {}; + +struct EpilogueOpNoBias {}; + +template +struct Epilogue {}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, + ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling, + cutlass::FloatRoundStyle::round_to_nearest, true>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h new file mode 100644 index 0000000000000..a5faad423fad9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace ort_fastertransformer { +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. +enum class CutlassTileConfig { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape128x32x64 +}; + +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + // SPLIT_K_PARALLEL // Not supported yet +}; + +struct CutlassGemmConfig { + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h new file mode 100644 index 0000000000000..311ed323cb90c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h @@ -0,0 +1,79 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/matrix_coord.h" + +#include "moe_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct GemmMoeProblemVisitor + : public MoeProblemVisitor, ThreadblockShape, + GroupScheduleMode_, PrefetchTileCount, ThreadCount> { + static bool const kTransposed = Transposed; + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base = + MoeProblemVisitor; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, shared_storage_, block_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h new file mode 100644 index 0000000000000..eb33a98e4246f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct LayoutDetailsB {}; + +// Volta specialiations. Volta will dequantize before STS, so we need a different operator +template +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct MixedGemmArchTraits {}; + +template +struct MixedGemmArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ========================= Volta Traits =========================== +// Volta will always dequantize after the global memory load. +// This will instantiate any HMMA tensorcore kernels for Volta. +template +struct MixedGemmArchTraits< + TypeA, TypeB, cutlass::arch::Sm70, + typename cutlass::platform::enable_if::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Turing Traits ============================== +template +struct MixedGemmArchTraits< + TypeA, TypeB, cutlass::arch::Sm75, + typename cutlass::platform::enable_if::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits< + TypeA, TypeB, cutlass::arch::Sm80, + typename cutlass::platform::enable_if::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h new file mode 100644 index 0000000000000..bfe30b71170d8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h @@ -0,0 +1,463 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "gemm_moe_problem_visitor.h" +#include "tile_interleaved_layout.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. +// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. +template +using void_t = void; + +template +struct use_dq_gemm : platform::false_type {}; + +template +struct use_dq_gemm> : platform::true_type {}; + +// SFINAE overload for dequantizing gemm +template ::value, bool>::type = true> +CUTLASS_DEVICE static void run_mma(Mma mma, int gemm_k_iterations, typename Mma::FragmentC& accum, + typename Mma::IteratorA iterator_A, typename Mma::IteratorB iterator_B, + typename Mma::FragmentC const& src_accum, ElementScale* weight_scale_ptr, + MatrixCoord scale_extent, const int thread_idx, MatrixCoord tb_offset_scale) { + typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), weight_scale_ptr, + scale_extent, thread_idx, tb_offset_scale); + + mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_scale, src_accum); +} + +// SFINAE overload for normal gemm. This completely ignores the scale parameters +template ::value, bool>::type = true> +CUTLASS_DEVICE static void run_mma(Mma mma, int gemm_k_iterations, typename Mma::FragmentC& accum, + typename Mma::IteratorA iterator_A, typename Mma::IteratorB iterator_B, + typename Mma::FragmentC const& src_accum, ElementScale* weight_scale_ptr, + MatrixCoord scale_extent, const int thread_idx, MatrixCoord tb_offset_scale) { + mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeFCGemm { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = + kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = + GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t* total_rows_before_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + weight_scales(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + total_rows_before_expert(nullptr), + gemm_n(0), + gemm_k(0), + host_problem_sizes(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, int threadblock_count, typename EpilogueOutputOp::Params output_op, + const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C, + ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, + GemmCoord* host_problem_sizes = nullptr) + : problem_count(problem_count), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(const_cast(ptr_A)), + ptr_B(const_cast(ptr_B)), + weight_scales(const_cast(weight_scales)), + ptr_C(const_cast(ptr_C)), + ptr_D(ptr_D), + total_rows_before_expert(total_rows_before_expert), + gemm_n(gemm_n), + gemm_k(gemm_k), + host_problem_sizes(nullptr) { + if (platform::is_same::value || platform::is_same::value) { + assert(weight_scales); + } + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, + tile_count), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + weight_scales(args.weight_scales), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D) {} + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params(args.total_rows_before_expert, args.gemm_n, args.gemm_k, + args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } + + static Status can_implement(Arguments const& args) { + if (args.weight_scales != nullptr) { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + // The dummy template parameter is not used and exists so that we can compile this code using + // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in + // a namespace + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) { CUTLASS_NOT_IMPLEMENTED(); } + }; + + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert(platform::is_same::value && kInterleave == 1 || + platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // + // Problem visitor. + // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset(int(cta_idx / grid_shape.n()) * Mma::Shape::kM, + int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump = + problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value ? gemm_n : gemm_k * kInterleave; + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); + run_mma(mma, gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, weight_scale_ptr, + {1, problem_size.n()}, thread_idx, tb_offset_scale); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; + ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h new file mode 100644 index 0000000000000..60608f462fde5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "ft_gemm_configs.h" + +namespace ort_fastertransformer { + +enum class ActivationType { Gelu, + Relu, + Silu, + GeGLU, + ReGLU, + SiGLU, + Identity, + InvalidType }; + +template +class MoeGemmRunner { + public: + MoeGemmRunner(); + + void initialize(int sm); + + void moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, ActivationType activation_type, cudaStream_t stream); + + void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream); + + private: + template + void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr); + + template + void run_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cudaStream_t stream); + + private: + int sm_; + int multi_processor_count_; +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu new file mode 100644 index 0000000000000..1d9a249db4237 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_kernels_template.h" + +namespace ort_fastertransformer { +template class MoeGemmRunner; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu new file mode 100644 index 0000000000000..7b250e6ca9060 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_kernels_template.h" + +namespace ort_fastertransformer { +template class MoeGemmRunner; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h new file mode 100644 index 0000000000000..66950c9b65970 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -0,0 +1,428 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" + +#include "compute_occupancy.h" +#include "epilogue_helpers.h" +#include "layout_traits_helper.h" +#include "moe_cutlass_kernel.h" + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +#include "cutlass_heuristic.h" +#include "moe_gemm_kernels.h" + +#include +#include +#include +#include + +namespace ort_fastertransformer { + +// ============================= Variable batched Gemm things =========================== +template +void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, const int multi_processor_count, + cudaStream_t stream, int* kernel_occupancy = nullptr) { + if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { + ORT_THROW("[FT Error][MoeGemm] Grouped gemm does not support split-k"); + } + + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float"); + + static_assert(cutlass::platform::is_same::value || + cutlass::platform::is_same::value || + cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + using ElementType = ElementType_; + + using CutlassWeightType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, + WeightType>::type; + using CutlassWeightType = CutlassWeightType_; + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = + typename Epilogue::Op; + + // Finally, set up the kernel. + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< + ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, + CutlassWeightType, typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone, + MixedGemmArchTraits::ElementsPerAccessB, ElementType, cutlass::layout::RowMajor, ElementAccumulator, + typename MixedGemmArchTraits::OperatorClass, arch, ThreadblockShape, WarpShape, + typename MixedGemmArchTraits::InstructionShape, EpilogueOp, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) { + *kernel_occupancy = compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + if (occupancy == 0) { + ORT_THROW("[FT Error][MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM kernel"); + } + const int threadblock_count = multi_processor_count * occupancy; + + typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); + + typename GemmGrouped::Arguments args( + num_experts, threadblock_count, epilogue_op, reinterpret_cast(A), + reinterpret_cast(B), reinterpret_cast(weight_scales), + reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, gemm_n, + gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = + "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass variable batched gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } +} + +template +struct dispatch_stages { + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); + } +}; + +template +struct dispatch_stages { + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + } +}; + +template +struct dispatch_stages 2)>::type> { + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + } +}; + +template +void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.stages) { + case 2: + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case 3: + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case 4: + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + ORT_THROW("[FT Error][MoE][dispatch_gemm_config] " + err_msg); + break; + } +} + +// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. +// This overload is only enabled when T == WeightType. +template < + typename T, typename WeightType, typename arch, typename EpilogueTag, + typename std::enable_if::value && std::is_same::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::Undefined: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; + default: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for same type MoE tensorop GEMM."); + break; + } +} + +// Tensorop GEMM overload +// Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve +// compile time +template < + typename T, typename WeightType, typename arch, typename EpilogueTag, + typename std::enable_if::value && !std::is_same::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::Undefined: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; + default: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for mixed type tensorop GEMM."); + break; + } +} + +// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. +template ::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::Undefined: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW( + "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config should have already been set by heuristic."); + break; + default: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config for float MoE gemm."); + break; + } +} + +template +MoeGemmRunner::MoeGemmRunner() {} + +template +void MoeGemmRunner::initialize(int sm_version) { + int device{-1}; + cudaGetDevice(&device); + sm_ = sm_version; + cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device); +} + +template +template +void MoeGemmRunner::dispatch_to_arch(const T* A, const WeightType* B, + const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, cudaStream_t stream, + int* occupancy) { + if (sm_ >= 70 && sm_ < 75) { + dispatch_moe_gemm_to_cutlass( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + sm_, multi_processor_count_, stream, occupancy); + } else if (sm_ >= 75 && sm_ < 80) { + dispatch_moe_gemm_to_cutlass( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + sm_, multi_processor_count_, stream, occupancy); + } else if (sm_ >= 80 && sm_ < 90) { + dispatch_moe_gemm_to_cutlass( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + sm_, multi_processor_count_, stream, occupancy); + } else { + ORT_THROW("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); + } +} + +template +template +void MoeGemmRunner::run_gemm(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream) { + static constexpr bool is_weight_only = !std::is_same::value; + static constexpr bool only_simt_configs = std::is_same::value; + std::vector candidate_configs = get_candidate_configs(sm_, is_weight_only, only_simt_configs); + std::vector occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + dispatch_to_arch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, candidate_configs[ii], stream, &occupancies[ii]); + } + + static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. + static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. + CutlassGemmConfig chosen_config = + estimate_best_config_from_occupancies(candidate_configs, occupancies, total_rows, gemm_n, gemm_k, num_experts, + split_k_limit, workspace_bytes, multi_processor_count_, is_weight_only); + + dispatch_to_arch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, chosen_config, stream); +} + +template +void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, ActivationType activation_type, + cudaStream_t stream) { + switch (activation_type) { + case ActivationType::Relu: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + break; + case ActivationType::Gelu: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); + break; + case ActivationType::Silu: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + break; + case ActivationType::Identity: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + break; + case ActivationType::InvalidType: + ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + break; + default: { + ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + } + } +} + +template +void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, + int64_t gemm_k, int num_experts, cudaStream_t stream) { + run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); +} + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu new file mode 100644 index 0000000000000..398ce4ee9880f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -0,0 +1,830 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +#include "moe_kernel.h" + +#if CUDA_VERSION >= 11000 +#include +#include +#include +#else +#include "cub/cub.cuh" +#include "cub/device/device_radix_sort.cuh" +#include "cub/util_type.cuh" +#endif + +namespace ort_fastertransformer { + +static constexpr int WARP_SIZE = 32; + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void moe_softmax(const T* input, const bool* finished, T* output, const int num_cols) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = T(val); + } +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, const int) { + // Does not support pre-Kepler architectures + ; +} +#else +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output, + int* indices, int* source_rows, int num_experts, int k) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool should_process_row = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? result_kvp.key : num_experts; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} +#endif + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + void topk_gating_softmax(const T* input, const bool* finished, T* output, int num_rows, int* indices, + int* source_rows, int k) { + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) return; + const bool should_process_row = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. + using AccessType = cutlass::AlignedArray; + + // Finally, we pull in the data from global mem + cutlass::Array row_chunk_input; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk_input); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + using ComputeType = float; + using Converter = cutlass::NumericArrayConverter; + Converter compute_type_converter; + cutlass::Array row_chunk = compute_type_converter(row_chunk_input); + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + ComputeType thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index.​ + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); + int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = T(max_val); + indices[idx] = should_process_row ? expert : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = ComputeType(-10000.f); + } + } + } +} + +namespace detail { +// Constructs some constants needed to partition the work across threads at compile time. +template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = std::max(1, (int)EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, + int num_rows, int num_experts, int k, cudaStream_t stream) { + static constexpr unsigned long MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topk_gating_softmax + <<>>(input, finished, output, num_rows, indices, source_row, k); +} + +template +void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output, + int* indices, int* source_row, int num_rows, int num_experts, + int k, cudaStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + + switch (num_experts) { + case 2: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 4: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 8: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 16: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 32: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 64: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 128: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 256: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + default: { + static constexpr int TPB = 256; + moe_softmax<<>>(input, finished, softmax_temp_output, num_experts); + moe_top_k + <<>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k); + } + } +} + +// ========================== CUB Sorting things ==================================== +CubKeyValueSorter::CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +CubKeyValueSorter::CubKeyValueSorter(int num_experts) + : num_experts_(num_experts), num_bits_((int)log2(num_experts) + 1) {} + +void CubKeyValueSorter::update_num_experts(int num_experts) { + num_experts_ = num_experts; + num_bits_ = (int)log2(num_experts) + 1; +} + +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + cub::DeviceRadixSort::SortPairs(NULL, required_storage, null_int, null_int, null_int, null_int, + (int)num_key_value_pairs, 0, num_bits_); + return required_storage; +} + +void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, + const int* values_in, int* values_out, const size_t num_key_value_pairs, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + ORT_THROW("Error. The allocated workspace is too small to run this problem. Expected workspace size of at least ", + expected_ws_size, " but got problem size ", workspace_size, "\n"); + } + cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, + (int)num_key_value_pairs, 0, num_bits_, stream); +} + +// ============================== Infer GEMM sizes ================================= +__device__ inline int find_total_elts_leq_target(const int* sorted_indices, const int arr_length, const int target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +// Sets up the gemm assuming the inputs, experts and outputs are stored in row major order. +// Assumes we want to perform output = matmul(inputs, experts) + bias +__global__ void compute_total_rows_before_expert_kernel(const int* sorted_experts, const int sorted_experts_len, + const int64_t num_experts, int64_t* total_rows_before_expert) { + // First, compute the global tid. We only need 1 thread per expert. + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) return; + + // This should construct the last index where each expert occurs. + total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); +} + +template +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version) { + moe_gemm_runner_.initialize(sm_version); +} + +template +size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows, const int hidden_size, + const int inter_size, int num_experts, + int k) { + const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size)); + const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); + const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); + const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); + int num_softmax_outs = 0; + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + num_softmax_outs = static_cast(pad_to_multiple_of_16(num_rows * num_experts)); + } + + // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them + // in Encoder or Decoder before invoking FfnLayer forward. + size_t total_ws_bytes = 3 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += buf_size * sizeof(T); // permuted_data + total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ + total_ws_bytes += num_softmax_outs * sizeof(T); + const int bytes_for_fc1_result = interbuf_size * sizeof(T); + const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows))); + sorter_.update_num_experts(num_experts); + + int bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + int remaining_bytes = static_cast(pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result)); + bytes_for_intermediate_and_sorting += remaining_bytes; + } + + total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + return total_ws_bytes; +} + +template +void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, int num_rows, + const int hidden_size, const int inter_size, + int num_experts, int k) { + const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size)); + const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); + const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); + const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); + // const int num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts); + + source_rows_ = (int*)ws_ptr; + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + permuted_data_ = (T*)(permuted_experts_ + num_moe_inputs); + + total_rows_before_expert_ = (int64_t*)(permuted_data_ + buf_size); + + fc1_result_ = (T*)(total_rows_before_expert_ + padded_experts); + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + softmax_out_ = (T*)(fc1_result_ + interbuf_size); + } else { + softmax_out_ = nullptr; + } +} + +template +void CutlassMoeFCRunner::run_moe_fc( + const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, + const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, + int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales, + int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { + static constexpr bool scales_required = + std::is_same::value || std::is_same::value; + + if (scales_required) { + if (fc1_scales == nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for first matmul is a null pointer"); + } else if (fc2_scales == nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for second matmul is a null pointer"); + } + } else { + if (fc1_scales != nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC1"); + } else if (fc2_scales != nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC2"); + } + } + + configure_ws_ptrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, k); + topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, + source_rows_, num_rows, num_experts, k, stream); + + const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); + sorter_.run((void*)fc1_result_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_, + permuted_rows_, k * num_rows, stream); + + initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_, + expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size, k, + stream); + + const int expanded_active_expert_rows = k * active_rows; + compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts, + total_rows_before_expert_, stream); + + moe_gemm_runner_.moe_gemm_bias_act(permuted_data_, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_, + total_rows_before_expert_, expanded_active_expert_rows, inter_size, hidden_size, + num_experts, fc1_activation_type, stream); + + moe_gemm_runner_.moe_gemm(fc1_result_, fc2_expert_weights, fc2_scales, fc2_result, total_rows_before_expert_, + expanded_active_expert_rows, hidden_size, inter_size, num_experts, stream); +} + +template +void CutlassMoeFCRunner::run_moe_fc( + const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, + const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, + int k, char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, + int* expert_for_source_row, cudaStream_t stream) { + run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, + fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, k, workspace_ptr, + fc2_result, nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, + expert_for_source_row, stream); +} + +template +void CutlassMoeFCRunner::compute_total_rows_before_expert(const int* sorted_indices, + const int total_indices, + int num_experts, + int64_t* total_rows_before_expert, + cudaStream_t stream) { + const int threads = std::min(1024, num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>(sorted_indices, total_indices, num_experts, + total_rows_before_expert); +} + +// ========================== Permutation things ======================================= + +// Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing. + +// "expanded_x_row" simply means that the number of values is num_rows x k. It is "expanded" since we will have to +// duplicate some rows in the input matrix to match the dimensions. Duplicates will always get routed to separate +// experts in the end. + +// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0, +// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input +// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus +// of the expanded index. + +template +__global__ void initialize_moe_routing_kernel(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, + int active_rows, int cols) { + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the + // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; + } + + if (blockIdx.x < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + dest_row_ptr[tid] = source_row_ptr[tid]; + } + } +} + +template +void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, + int active_rows, int cols, int k, cudaStream_t stream) { + const int blocks = num_rows * k; + const int threads = std::min(cols, 1024); + initialize_moe_routing_kernel + <<>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, num_rows, k * active_rows, cols); +} + +// Final kernel to unpermute and scale +// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection. +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 +template +__global__ void finalize_moe_routing_kernel(const T*, T*, const T*, const T*, const T*, const T*, const int*, + const int*, int, const int) { + // Does not support pre-Kepler architectures + ; +} +#else +template +__global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, + const T* skip_1, const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int cols, int k) { + const int original_row = blockIdx.x; + int num_rows = gridDim.x; + T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + + const T* skip_1_row_ptr = nullptr; + if (RESIDUAL_NUM == 1) { + skip_1_row_ptr = skip_1 + original_row * cols; + } + const T* skip_2_row_ptr = nullptr; + if (RESIDUAL_NUM == 2) { + skip_2_row_ptr = skip_2 + original_row * cols; + } + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + T thread_output; + if (RESIDUAL_NUM == 0) { + thread_output = T(0); + } else if (RESIDUAL_NUM == 1) { + thread_output = skip_1_row_ptr[tid]; + } else if (RESIDUAL_NUM == 2) { + thread_output = skip_1_row_ptr[tid] + skip_2_row_ptr[tid]; + } + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + const int64_t k_offset = original_row * k + k_idx; + const T row_scale = scales[k_offset]; + const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = expert_for_source_row[k_offset]; + const T* bias_ptr = bias + expert_idx * cols; + + thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]); + } + reduced_row_ptr[tid] = thread_output; + } +} +#endif + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, + const T* scales, const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream) { + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, nullptr, nullptr, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); +} + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, + const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream) { + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + finalize_moe_routing_kernel + <<>>(expanded_permuted_rows, reduced_unpermuted_output, skip, nullptr, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); +} + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, + const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream) { + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + if (skip_2 == nullptr) { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + } else { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + } +} + +// ========================= TopK Softmax specializations =========================== +template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int, + int, int, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, + int, int, cudaStream_t); + +// ==================== Variable batched GEMM specializations ================================== +template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner; + +// ===================== Specializations for init routing ========================= +template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, + int, int, cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, + int, int, cudaStream_t); + +// ==================== Specializations for final routing =================================== +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, + const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const int*, const int*, + int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, + const int*, const int*, int, int, int, + cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const int*, + const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, + const float*, const int*, const int*, int, int, int, + cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, + const half*, const int*, const int*, int, int, int, + cudaStream_t); + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h new file mode 100644 index 0000000000000..5cefe4fa5dc47 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "moe_gemm_kernels.h" +#include + +#include "core/common/common.h" + +using namespace onnxruntime; + +namespace ort_fastertransformer { + +static inline size_t pad_to_multiple_of_16(size_t input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +/* + Launches the topk gating softmax required for the MoE layers. + + Params: + input - a [num_rows x num_experts] + finished - [num_rows] vector with 1 if the sentence at this row is done translating and 0 otherwise. + output - a buffer of shape [num_rows x k] containing the top-k values of the softmax for each row. + indices - a matrix of shape [num_rows x k] containing the top-k experts each row should get routed to. + source_rows - a matrix of shape [num_rows x k] used internally for permuting. source_rows[row][k] = k * num_rows + + row. It is constructed like this so we can track where each of the original rows end up in order to perform the + "k-way" reduction later in the routing. + + num_rows - The number of rows in the matrix + num_experts - The number of expert layers present + k - k value in topk +*/ +template +void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out, + int* indices, int* source_row, int num_rows, int num_experts, + int k, cudaStream_t stream); + +class CubKeyValueSorter { + public: + CubKeyValueSorter(); + + CubKeyValueSorter(int num_experts); + + void update_num_experts(int num_experts); + + size_t getWorkspaceSize(const size_t num_key_value_pairs); + + void run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, const int* values_in, + int* values_out, const size_t num_key_value_pairs, cudaStream_t stream); + + private: + size_t num_key_value_pairs_; + int num_experts_; + int num_bits_; +}; + +template +void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, + int active_rows, int cols, int k, cudaStream_t stream); + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, + const T* scales, const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream); + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, + const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream); + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, + const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream); + +// Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc . +// Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive. +// Avoid making several duplicates of this class. +template +class CutlassMoeFCRunner { + public: + CutlassMoeFCRunner(int sm_version); + + size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k); + + void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, + const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, + const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, + int inter_size, int num_experts, int k, char* workspace_ptr, T* fc2_result, + T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, + cudaStream_t stream); + + void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, + const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, + const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, + int inter_size, int num_experts, int k, char* workspace_ptr, T* fc2_result, + const bool* finished, int active_rows, T* expert_scales, + int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream); + + void compute_total_rows_before_expert(const int* sorted_indices, int total_indices, int num_experts, + int64_t* total_rows_before_expert, cudaStream_t stream); + + private: + void configure_ws_ptrs(char* ws_ptr, int num_rows, int hidden_size, int inter_size, int num_experts, int k); + + private: + CubKeyValueSorter sorter_; + MoeGemmRunner moe_gemm_runner_; + + // Pointers + int* source_rows_; + int* permuted_rows_; + int* permuted_experts_; + char* sorter_ws_; + T* permuted_data_; + T* softmax_out_; + + int64_t* total_rows_before_expert_; + + T* fc1_result_; +}; + +template +class CutlassMoeFCRunner::value>> { + public: + CutlassMoeFCRunner(int sm_version); + + size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k) { + return 0; + } +}; + +} // namespace ort_fastertransformer \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h new file mode 100644 index 0000000000000..00f977c615df6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Base scheduler for grouped problems, using MoE +*/ + +#pragma once + +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct BaseMoeProblemVisitor { + using ThreadblockShape = ThreadblockShape_; + + struct ProblemInfo { + static int32_t const kNoPrefetchEntry = -1; + int32_t problem_idx; + int32_t problem_start; + + CUTLASS_DEVICE + ProblemInfo() : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} + + CUTLASS_DEVICE + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) + : problem_idx(problem_idx_), problem_start(problem_start_) {} + }; + + struct Params { + int64_t const* last_row_for_problem; + int64_t gemm_n; + int64_t gemm_k; + int32_t problem_count; + void const* workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : last_row_for_problem(nullptr), gemm_n(0), gemm_k(0), problem_count(0), workspace(nullptr), tile_count(0) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, + void const* workspace = nullptr, int32_t tile_count = 0) + : last_row_for_problem(last_row_for_problem), + gemm_n(gemm_n), + gemm_k(gemm_k), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) {} + }; + + Params const& params; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) + : params(params_), tile_idx(block_idx), problem_tile_start(0), problem_idx(0) {} + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const { return tile_idx; } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const { return problem_idx; } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const { return tile_idx - problem_tile_start; } + + CUTLASS_DEVICE + void advance(int32_t grid_size) { tile_idx += grid_size; } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { + ProblemSizeHelper::possibly_transpose_problem(problem); + } + + /// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const { return problem_size(problem_idx); } + + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size(int idx) const { + const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t gemm_m = current_problem_row - prev_problem_row; + GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { return ProblemSizeHelper::tile_count(grid); } + + static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); + } + + return total_tiles; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeProblemVisitor; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ProblemVisitor that performs all scheduling on device +// +template +struct MoeProblemVisitor : public BaseMoeProblemVisitor { + using Base = BaseMoeProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage {}; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage& shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, block_idx), problem_ending_tile(0), shared_storage(shared_storage_) { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() { + // Check whether the tile to compute is within the range of the current problem. + int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); + if (this->tile_idx < problem_tile_end) { + return true; + } + + // Check whether the tile to compute is within the current group of problems fetched by the warp. + // The last tile for this group is the final tile of the problem held by the final thread in the warp. + int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + // Keep the starting problem for this group in `problem_idx`. This is done to reduce + // register pressure. The starting problem for this group is simply the first problem + // in the group most recently fetched by the warp. + int32_t& group_problem_start = this->problem_idx; + group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; + + // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce + // register pressure. + int32_t& group_tile_start = this->problem_tile_start; + + // Each thread in the warp processes a separate problem to advance until + // reaching a problem whose starting tile is less less than tile_idx. + while (group_tile_end <= this->tile_idx) { + group_problem_start += kThreadsPerWarp; + if (group_problem_start > this->params.problem_count) { + return false; + } + + // Since `group_tile_start` is a reference to `this->problem_tile_start`, this + // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` + // is also set here is used later in `next_tile`. + group_tile_start = group_tile_end; + + int lane_idx = threadIdx.x % kThreadsPerWarp; + int32_t lane_problem = group_problem_start + lane_idx; + + // Compute the number of tiles in the problem assigned to each thread. + problem_ending_tile = 0; + if (lane_problem < this->params.problem_count) { + cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + problem_ending_tile = this->tile_count(grid); + } + + // Compute a warp-wide inclusive prefix sum to compute the ending tile index of + // each thread's problem. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kThreadsPerWarp; i <<= 1) { + int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); + if (lane_idx >= i) { + problem_ending_tile += val; + } + } + + // The total tile count for this group is now in the final position of the prefix sum + int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + problem_ending_tile += group_tile_start; + group_tile_end += tiles_in_group; + } + + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); + + this->problem_idx = group_problem_start + problem_idx_in_group; + + // The starting tile for this problem is the ending tile of the previous problem. In cases + // where `problem_idx_in_group` is the first problem in the group, we do not need to reset + // `problem_tile_start`, because it is set to the previous group's ending tile in the while + // loop above. + if (problem_idx_in_group > 0) { + this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); + } + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count, void* host_workspace_ptr) {} +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h new file mode 100644 index 0000000000000..3505bea24e4d9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass { +namespace layout { + +template +class ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc new file mode 100644 index 0000000000000..6f2ffe7a0cc43 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_common.h" +#include "moe.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MoE, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(0, 0) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + MoE); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +using namespace ONNX_NAMESPACE; + +template +Status MoE::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* router_probs = context->Input(1); + const Tensor* fc1_experts_weights = context->Input(2); + const Tensor* fc2_experts_weights = context->Input(3); + const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc2_experts_bias_optional = context->Input(5); + + const auto& input_dims = input->Shape().GetDims(); + const auto& router_probs_dims = router_probs->Shape().GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); + const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); + + const int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; + const int64_t hidden_size = input_dims[input_dims.size() - 1]; + const int64_t num_experts = fc1_experts_weights_dims[0]; + const int64_t inter_size = fc1_experts_weights_dims[2]; + + // TODO: refactor to helper function. + if (fc1_experts_weights_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", + fc1_experts_weights_dims.size()); + } + if (fc2_experts_weights_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ", + fc2_experts_weights_dims.size()); + } + if (fc1_experts_weights_dims[1] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_weights_dims[1] must be equal to hidden_size, got ", + fc1_experts_weights_dims[1], " and ", hidden_size); + } + if (fc2_experts_weights_dims[1] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_weights_dims[1] must be equal to inter_size, got ", fc2_experts_weights_dims[1], + " and ", inter_size); + } + if (fc1_experts_weights_dims[2] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_weights_dims[2] must be equal to inter_size, got ", fc1_experts_weights_dims[2], + " and ", inter_size); + } + if (fc2_experts_weights_dims[2] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", + fc2_experts_weights_dims[2], " and ", hidden_size); + } + if (router_probs_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", + router_probs_dims.size()); + } + if (router_probs_dims[0] != num_rows) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ", + router_probs_dims[0], " and ", num_rows); + } + if (router_probs_dims[1] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[1] must be equal to num_experts, got ", + router_probs_dims[1], " and ", num_experts); + } + if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is set but fc2_experts_bias is not set"); + } + if (fc1_experts_bias_optional == nullptr && fc2_experts_bias_optional != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is not set but fc2_experts_bias is set"); + } + if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) { + const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims(); + const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims(); + if (fc1_experts_bias_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ", + fc1_experts_bias_dims.size()); + } + if (fc2_experts_bias_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ", + fc2_experts_bias_dims.size()); + } + if (fc1_experts_bias_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_bias_dims[0] must be equal to num_experts, got ", fc1_experts_bias_dims[0], + " and ", num_experts); + } + if (fc2_experts_bias_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], + " and ", num_experts); + } + if (fc1_experts_bias_dims[1] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], + " and ", inter_size); + } + if (fc2_experts_bias_dims[1] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", fc2_experts_bias_dims[1], + " and ", hidden_size); + } + } + + typedef typename ToCudaType::MappedType CudaT; + auto stream = context->GetComputeStream(); + + auto& device_prop = GetDeviceProp(); + const int sm = device_prop.major * 10 + device_prop.minor; + + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm); + + size_t ws_size = + moe_runner.getWorkspaceSize(static_cast(num_rows), static_cast(hidden_size), + static_cast(inter_size), static_cast(num_experts), static_cast(k_)); + size_t fc2_output_size = k_ * num_rows * hidden_size * sizeof(CudaT); + size_t expert_scales_size = k_ * num_rows * sizeof(CudaT); + size_t expanded_source_row_to_expanded_dest_row_size = k_ * num_rows * sizeof(int); + size_t expert_for_source_row_size = k_ * num_rows * sizeof(int); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + // TODO: allocate one buffer and reuse it. + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); + IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); + IAllocatorUniquePtr expert_scales = + IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, stream); + IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = + IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream); + IAllocatorUniquePtr expert_for_source_row = + IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); + + // fc1_scales and fc2_scales are used in quantized MoE + const CudaT* fc1_scales_ptr = nullptr; + const CudaT* fc2_scales_ptr = nullptr; + + moe_runner.run_moe_fc(reinterpret_cast(input->template Data()), + reinterpret_cast(router_probs->template Data()), + reinterpret_cast(fc1_experts_weights->template Data()), + std::move(fc1_scales_ptr), + fc1_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc1_experts_bias_optional->template Data()), + activation_type_, reinterpret_cast(fc2_experts_weights->template Data()), + std::move(fc2_scales_ptr), static_cast(num_rows), static_cast(hidden_size), + static_cast(inter_size), static_cast(num_experts), static_cast(k_), + reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), Stream(context)); + + Tensor* output = context->Output(0, input->Shape()); + + ort_fastertransformer::finalize_moe_routing_kernelLauncher( + reinterpret_cast(fc2_output.get()), reinterpret_cast(output->template MutableData()), + fc2_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc2_experts_bias_optional->template Data()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), static_cast(num_rows), static_cast(hidden_size), + static_cast(k_), Stream(context)); + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h new file mode 100644 index 0000000000000..8035568693814 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/moe.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class MoE final : public CudaKernel { + public: + explicit MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK()); + + std::string activation_type_str; + ORT_ENFORCE(op_kernel_info.GetAttr("activation_type", &activation_type_str).IsOK()); + if (activation_type_str == "relu") { + activation_type_ = ort_fastertransformer::ActivationType::Relu; + } else if (activation_type_str == "gelu") { + activation_type_ = ort_fastertransformer::ActivationType::Gelu; + } else if (activation_type_str == "silu") { + activation_type_ = ort_fastertransformer::ActivationType::Silu; + } else if (activation_type_str == "identity") { + activation_type_ = ort_fastertransformer::ActivationType::Identity; + } else { + ORT_THROW("Unsupported MoE activation type: ", activation_type_str); + } + } + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + int64_t k_; + ort_fastertransformer::ActivationType activation_type_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu index e58723f0b31e1..2f74dd41f0759 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -35,6 +35,8 @@ template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, c template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream); +template Status SetBnbQuantMap(int quant_type, BFloat16* quant_map_buffer, cudaStream_t stream); + template __global__ void kDequantizeBlockwise( const T* quant_map, @@ -62,22 +64,15 @@ __global__ void kDequantizeBlockwise( valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2; - local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]); + local_abs_max = absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]; __syncthreads(); LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max; - vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max; - #else - // half multiplication not supported - vals[j * 2] = static_cast(static_cast(quant_map[qvals[j] >> 4]) * static_cast(local_abs_max)); - vals[j * 2 + 1] = - static_cast(static_cast(quant_map[qvals[j] & 0x0F]) * static_cast(local_abs_max)); - #endif + vals[j * 2] = ScalarMul(quant_map[qvals[j] >> 4], local_abs_max); + vals[j * 2 + 1] = ScalarMul(quant_map[qvals[j] & 0x0F], local_abs_max); } __syncthreads(); @@ -86,7 +81,7 @@ __global__ void kDequantizeBlockwise( } template -Status DequantizeBnb4( +void CallkDequantizeBlockwise( const T* quant_map, T* output, const uint8_t* quant_data, @@ -102,6 +97,18 @@ Status DequantizeBnb4( absmax, block_size / 2, numel); +} + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); return Status::OK(); } @@ -119,11 +126,36 @@ template Status DequantizeBnb4( const half* quant_map, half* output, const uint8_t* quant_data, - const half *absmax, + const half* absmax, int block_size, int numel, cudaStream_t stream); +template <> +Status DequantizeBnb4( + const BFloat16* quant_map, + BFloat16* output, + const uint8_t* quant_data, + const BFloat16* absmax, + int block_size, + int numel, + cudaStream_t stream) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + CallkDequantizeBlockwise( + reinterpret_cast(quant_map), + reinterpret_cast(output), + quant_data, + reinterpret_cast(absmax), + block_size, + numel, + stream); + #else + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); + #endif + + return Status::OK(); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh index 4aef3ab699f9c..a0d38c9853cd6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -11,6 +11,38 @@ namespace cuda { template Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream); +// templated scalar multiply function +template +__device__ inline T ScalarMul(T a, T b); + +template <> +__device__ inline float ScalarMul(float a, float b) { + return a * b; +} + +template <> +__device__ inline half ScalarMul(half a, half b) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return a * b; + #else + // half multiplication not supported + return static_cast(static_cast(a) * static_cast(b)); + #endif +} + +template <> +__device__ inline BFloat16 ScalarMul(BFloat16 a, BFloat16 b) { + return a * b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// will use the native bfloat16 multiply instruction on sm_80+ +template <> +__device__ inline nv_bfloat16 ScalarMul(nv_bfloat16 a, nv_bfloat16 b) { + return a * b; +} +#endif + template Status DequantizeBnb4( const T* quant_map, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index ecf332715d470..bbcb7de99781f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -145,6 +145,17 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("T2", DataTypeImpl::GetTensorType()), MatMulBnb4); +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu index 1d9aa75ff3701..098e3618beddd 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -6,12 +6,44 @@ #include #include #include +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" #include "matmul_bnb4.cuh" namespace onnxruntime { namespace contrib { namespace cuda { +template +__device__ inline float ScalarMulFloatOut(T a, T b); + +template <> +__device__ inline float ScalarMulFloatOut(float a, float b) { + return a * b; +} + +template <> +__device__ inline float ScalarMulFloatOut(half a, half b) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return static_cast(a * b); + #else + // half multiplication not supported + return static_cast(a) * static_cast(b); + #endif +} + +template <> +__device__ inline float ScalarMulFloatOut(BFloat16 a, BFloat16 b) { + return a * b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// will use the native bfloat16 multiply instruction on sm_80+ +template <> +__device__ inline float ScalarMulFloatOut(nv_bfloat16 a, nv_bfloat16 b) { + return static_cast(a * b); +} +#endif + #define num_values_4bit 32 template __global__ void kgemm_4bit_inference_naive( @@ -55,7 +87,7 @@ __global__ void kgemm_4bit_inference_naive( int inner_idx_halved = inner_idx / 2; int offset_B = ldb * row_B; int absidx = ((2 * offset_B) + inner_idx) / block_size; - local_absmax = __ldg(&(absmax[absidx])); + local_absmax = absmax[absidx]; if (row_B < N) { if ((inner_idx_halved + num_values_8bit) < (K / 2)) { @@ -78,18 +110,8 @@ __global__ void kgemm_4bit_inference_naive( for (int i = 0; i < 4; i++) { #pragma unroll for (int k = 0; k < num_values_8bit / 4; k++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; - local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; - #else - // half multiplication not supported - local_B[k * 2] = - static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) * - static_cast(local_absmax)); - local_B[k * 2 + 1] = - static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) * - static_cast(local_absmax)); - #endif + local_B[k * 2] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4], local_absmax); + local_B[k * 2 + 1] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F], local_absmax); } if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { @@ -116,12 +138,7 @@ __global__ void kgemm_4bit_inference_naive( // accumulate in float; small performance hit for Ampere, but lower error for outputs #pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - local_C += static_cast(local_A[k] * local_B[k]); - #else - // half multiplication not supported - local_C += static_cast(local_A[k]) * static_cast(local_B[k]); - #endif + local_C += ScalarMulFloatOut(local_A[k], local_B[k]); } } } @@ -131,8 +148,19 @@ __global__ void kgemm_4bit_inference_naive( if (row_B < N && warp_lane == 0) out[row_B] = T(local_C); } +bool CheckDims(int m, int k, int block_size) { + if (k % block_size != 0 || m > 1) { + return false; + } + // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] + if (block_size % 32 != 0 || block_size > 4096) { + return false; + } + return true; +} + template -bool TryMatMulBnb4( +void Callkgemm_4bit_inference_naive( const T* quant_map, T* output, const T* a_data, @@ -143,22 +171,34 @@ bool TryMatMulBnb4( int k, int block_size, cudaStream_t stream) { - if (k % block_size != 0 || m > 1) { - return false; - } - // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] - if (block_size % 32 != 0 || block_size > 4096) { - return false; - } - int lda = k; int ldb = (k + 1) / 2; int ldc = n; int num_blocks = (n + 3) / 4; - constexpr int bits = std::is_same_v ? 16 : 32; + constexpr int bits = std::is_same_v ? 32 : 16; kgemm_4bit_inference_naive<<>>( m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size); +} + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (!CheckDims(m, k, block_size)) { + return false; + } + + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); return true; } @@ -187,6 +227,42 @@ template bool TryMatMulBnb4( int block_size, cudaStream_t stream); +template <> +bool TryMatMulBnb4( + const BFloat16* quant_map, + BFloat16* output, + const BFloat16* a_data, + const uint8_t* b_data_quant, + const BFloat16* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (!CheckDims(m, k, block_size)) { + return false; + } + + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + Callkgemm_4bit_inference_naive( + reinterpret_cast(quant_map), + reinterpret_cast(output), + reinterpret_cast(a_data), + b_data_quant, + reinterpret_cast(absmax), + m, + n, + k, + block_size, + stream); + #else + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); + #endif + + return true; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/attention.cc b/onnxruntime/contrib_ops/js/bert/attention.cc new file mode 100644 index 0000000000000..723ff00aa815e --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + Attention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + Attention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/attention.h b/onnxruntime/contrib_ops/js/bert/attention.h new file mode 100644 index 0000000000000..0fa823befa9b2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class Attention : public JsKernel, AttentionBase { + public: + explicit Attention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + std::vector qkv_sizes(qkv_hidden_sizes_.size()); + if (qkv_hidden_sizes_.size() > 0) { + std::transform(qkv_hidden_sizes_.begin(), qkv_hidden_sizes_.end(), qkv_sizes.begin(), + [](int64_t sz) { return gsl::narrow_cast(sz); }); + } + + JSEP_INIT_KERNEL_ATTRIBUTE(Attention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + "qkvHiddenSizes" : $6 ? (Array.from(HEAP32.subarray(Number($7), Number($7) + $6))) : [], + "pastPresentShareBuffer" : !!$8, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_), + static_cast(qkv_hidden_sizes_.size()), + reinterpret_cast((qkv_sizes.size() > 0) ? qkv_sizes.data() : nullptr) >> 2, + static_cast(past_present_share_buffer_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc new file mode 100644 index 0000000000000..c43f8b7f18465 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "multi_head_attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + MultiHeadAttention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.h b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h new file mode 100644 index 0000000000000..6c63a2ffed4b2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class MultiHeadAttention : public JsKernel, AttentionBase { + public: + explicit MultiHeadAttention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + JSEP_INIT_KERNEL_ATTRIBUTE(MultiHeadAttention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 24d327576ecd9..498a9f5679eb5 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -7,7 +7,9 @@ namespace onnxruntime { namespace contrib { namespace js { +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); @@ -21,7 +23,9 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo // For std::cerr. + // Writing to stderr instead of logging because logger may not be initialized yet. namespace onnxruntime { @@ -137,7 +138,7 @@ void decodeMIDR( break; // #endif /* ARM */ default: - LOGS_DEFAULT(WARNING) << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } } break; @@ -156,7 +157,7 @@ void decodeMIDR( break; // #endif default: - LOGS_DEFAULT(WARNING) << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if (defined(_M_ARM64) || defined(__aarch64__)) && !defined(__ANDROID__) @@ -172,7 +173,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_thunderx2; break; default: - LOGS_DEFAULT(WARNING) << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif @@ -187,7 +188,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_cortex_a76; break; default: - LOGS_DEFAULT(WARNING) << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -199,7 +200,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_xscale; break; default: - LOGS_DEFAULT(WARNING) << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ @@ -215,7 +216,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_carmel; break; default: - LOGS_DEFAULT(WARNING) << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #if !defined(__ANDROID__) @@ -225,7 +226,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_xgene; break; default: - LOGS_DEFAULT(WARNING) << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #endif @@ -297,7 +298,7 @@ void decodeMIDR( break; // #endif /* ARM64 && !defined(__ANDROID__) */ default: - LOGS_DEFAULT(WARNING) << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; case 'S': @@ -343,8 +344,9 @@ void decodeMIDR( *uarch = cpuinfo_uarch_exynos_m5; break; default: - LOGS_DEFAULT(WARNING) << "unknown Samsung CPU variant 0x" - << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Samsung CPU variant 0x" + << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) + << " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -355,12 +357,12 @@ void decodeMIDR( *uarch = cpuinfo_uarch_pj4; break; default: - LOGS_DEFAULT(WARNING) << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ default: - LOGS_DEFAULT(WARNING) << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr; + std::cerr << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr << "\n"; } } diff --git a/onnxruntime/core/framework/op_node_proto_helper.cc b/onnxruntime/core/framework/op_node_proto_helper.cc index 38d67eb0e0c72..c3deb94300e78 100644 --- a/onnxruntime/core/framework/op_node_proto_helper.cc +++ b/onnxruntime/core/framework/op_node_proto_helper.cc @@ -182,7 +182,7 @@ ORT_DEFINE_GET_ATTRS_SPAN_SPECIALIZATION(float, floats) ORT_DEFINE_GET_ATTRS_SPAN_SPECIALIZATION(int64_t, ints) template -MUST_USE_RESULT Status OpNodeProtoHelper::GetAttrs(const std::string& name, TensorShapeVector& out) const { +Status OpNodeProtoHelper::GetAttrs(const std::string& name, TensorShapeVector& out) const { gsl::span span; Status status = this->GetAttrsAsSpan(name, span); if (status.IsOK()) { @@ -193,7 +193,7 @@ MUST_USE_RESULT Status OpNodeProtoHelper::GetAttrs(const std::string& na } template -MUST_USE_RESULT Status OpNodeProtoHelper::GetAttrsStringRefs( +Status OpNodeProtoHelper::GetAttrsStringRefs( const std::string& name, std::vector>& refs) const { const AttributeProto* attr = TryGetAttribute(name); diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index df3a7afebc176..df11fe8302aef 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -455,11 +455,10 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& // utils::CopyOneInputAcrossDevices is happy. auto& input_map = session_state.GetInputNodeInfoMap(); - auto end_map = input_map.cend(); for (const auto& graph_input : graph_inputs) { const auto& name = graph_input->Name(); - if (input_map.find(name) == end_map) { + if (input_map.find(name) == input_map.cend()) { // dummy entry for an input that we didn't find a use of in the graph. log it in case that's a bug. // utils::CopyOneInputAcrossDevices will use the input OrtValue as is given we don't believe it's used anywhere. LOGS(session_state.Logger(), INFO) << (graph.IsSubgraph() ? "Subgraph" : "Graph") << " input with name " diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index dcde2ddeb8270..b97fb0d2899fc 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -991,7 +991,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( constexpr const char* GroupQueryAttention_ver1_doc = R"DOC( Group Query Self/Cross Attention. -Supports different number of heads for q and kv. +Supports different number of heads for q and kv. Only supports causal or local attention. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1004,10 +1004,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) - // .Attr("left_padding_last_token", - // "Copy last token to last index of buffer. Default is 0; 1 when true.", - // AttributeProto::INT, - // OPTIONAL_VALUE) + .Attr("local_window_size", + "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", + AttributeProto::INT, + static_cast(-1)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size)", @@ -1144,7 +1144,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OPTIONAL_VALUE) .Input(0, "input", - "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", "T") .Input(1, "position_ids", @@ -1160,7 +1160,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T") .Output(0, "output", - "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "tensor with same shape as input.", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 39449bea6303a..4c0d78f0ee297 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1375,6 +1375,27 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, GreedySearchShapeInference(ctx); })); +constexpr const char* MoE_ver1_doc = R"DOC( + Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, + GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) + usually uses top 32 experts. + )DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, + OpSchema() + .SetDoc(MoE_ver1_doc) + .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) + .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) + .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") + .Input(3, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") + .Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) + .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, OpSchema() .Input(0, "X", "input", "T") @@ -3410,7 +3431,7 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 .Input(1, "B", "1-dimensional quantized data for weight", "T2") .Input(2, "absmax", "quantization constants", "T1") .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") - .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float/half_float/brain_float tensors.") .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index b35cfc5d12f36..5eef1b33a24dd 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -83,6 +83,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4); #endif class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MoE); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); @@ -189,6 +190,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); #endif fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc index 7b0a834a7ffc0..a266c9ab04a2e 100644 --- a/onnxruntime/core/graph/function_utils.cc +++ b/onnxruntime/core/graph/function_utils.cc @@ -432,7 +432,7 @@ class Inliner { // Process a node: void transform(NodeProto& n) { if (!n.name().empty()) - n.set_name(prefix_ + n.name()); + n.set_name(prefix_ + "_" + n.name()); for (auto& x : *n.mutable_input()) { rename(x, false); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 4b3cafcb39b78..d489a59c4b798 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -984,6 +984,7 @@ bool Node::ClearAttribute(const std::string& attr_name) { graph_->SetGraphProtoSyncNeeded(); return attributes_.erase(attr_name) > 0; } + #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) int Node::PruneRemovableAttributes(gsl::span removable_attributes) { @@ -4047,6 +4048,330 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod return Status::OK(); } +static void ReassignSubgraphDependentNodeArgs(const InlinedHashMap& name_to_nodearg, + Graph& graph) { + for (auto& node : graph.Nodes()) { + if (node.ContainsSubgraph()) { + for (auto& [name, subgraph] : node.GetAttributeNameToMutableSubgraphMap()) { + ReassignSubgraphDependentNodeArgs(name_to_nodearg, *subgraph); + } + } + + // NodeArgs need to be updated + for (auto& input_def : node.MutableInputDefs()) { + if (input_def->Exists()) { + auto hit = name_to_nodearg.find(input_def->Name()); + if (hit != name_to_nodearg.cend()) { + // Make sure we create a local to this subgraph definition + const auto* new_name_arg = hit->second; + input_def = &graph.GetOrCreateNodeArg(new_name_arg->Name(), input_def->TypeAsProto()); + } + } + } + } +} + +Status Graph::InlineIfSubgraph(bool condition_value, Node& if_node, const logging::Logger& logger) { + static const std::string then_branch{"then_branch"}; + static const std::string else_branch{"else_branch"}; + Graph* sub_graph; + if (condition_value) { + sub_graph = if_node.GetMutableGraphAttribute(then_branch); + } else { + sub_graph = if_node.GetMutableGraphAttribute(else_branch); + } + + if (sub_graph == nullptr) { + auto str = MakeString("Unable to constant fold If node: '", if_node.Name(), "' Unable to fetch: ", + (condition_value ? then_branch : else_branch)); + LOGS(logger, WARNING) << str; + return Status::OK(); + } + + Graph& graph_to_inline = *sub_graph; + + std::string unique_id{"_if_"}; + if (condition_value) { + unique_id.append(then_branch); + } else { + unique_id.append(else_branch); + } + + unique_id = GenerateNodeName(unique_id); + + auto make_unique = [&unique_id](const std::string& name) { + return unique_id + '_' + name; + }; + + // Check if the name is an input or implicit input. + // These are not renamed, and we do not need to adjust subgraphs for them. + // Implicit inputs would cover both If node input and implicit inputs. + // Reason: there are no explicit inputs to the subgraphs, and the subgraph's + // implicit inputs must be covered by the implicit inputs of the If node. + InlinedHashMap outer_scope_values; + const auto& if_implicit_inputs = if_node.MutableImplicitInputDefs(); + outer_scope_values.reserve(if_implicit_inputs.size()); + + for (auto* input : if_implicit_inputs) { + const auto& name = input->Name(); + ORT_IGNORE_RETURN_VALUE(outer_scope_values.emplace(name, input)); + } + + // Name mapping from the graph to inline to the graph we are inlining into + // we also use this to process any subgraphs in the graph we are inlining + InlinedHashMap name_to_nodearg; + + // We are going to map the outputs of the graph to inline to the outputs of the If node. + // They are assumed to be in the same order. + const auto& node_output_defs = if_node.MutableOutputDefs(); + const auto& graph_output_defs = graph_to_inline.GetOutputs(); + for (size_t i = 0; i < graph_output_defs.size(); ++i) { + name_to_nodearg.emplace(graph_output_defs[i]->Name(), node_output_defs[i]); + } + + // Move initializers from the subgraph to the destination graph. + for (int i = 0, limit = graph_to_inline.graph_proto_->initializer_size(); i < limit; ++i) { + auto* initializer = graph_to_inline.graph_proto_->mutable_initializer(i); + const std::string src_name = initializer->name(); + +#if !defined(DISABLE_SPARSE_TENSORS) + bool has_sparse_origin = false; + if (!graph_to_inline.sparse_tensor_names_.empty()) { + auto hit = graph_to_inline.sparse_tensor_names_.find(src_name); + if (hit != graph_to_inline.sparse_tensor_names_.cend()) { + has_sparse_origin = true; + // Erase the entry that will be invalidated + graph_to_inline.sparse_tensor_names_.erase(hit); + } + } +#endif + + graph_to_inline.name_to_initial_tensor_.erase(src_name); + const gsl::not_null tensor{graph_proto_->add_initializer()}; + *tensor = std::move(*initializer); + + // Check if this is an output of the graph + auto hit = name_to_nodearg.find(src_name); + if (hit != name_to_nodearg.cend()) { + // We rename it to If node output. + tensor->set_name(hit->second->Name()); + } else { + NodeArg* node_arg = graph_to_inline.GetNodeArg(src_name); + assert(node_arg != nullptr); + auto new_name = GenerateNodeArgName(make_unique(src_name)); + NodeArg& new_arg = GetOrCreateNodeArg(new_name, node_arg->TypeAsProto()); + ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(src_name, &new_arg)); + tensor->set_name(std::move(new_name)); + } + + auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); + ORT_ENFORCE(insert_result.second, "Initializer name: ", tensor->name(), " from graph: ", + graph_to_inline.Name(), " conflicts with graph initializer. Check name generation above."); + +#if !defined(DISABLE_SPARSE_TENSORS) + if (has_sparse_origin) { + ORT_IGNORE_RETURN_VALUE(sparse_tensor_names_.emplace(tensor->name())); + } +#endif + } + + // Look up nodes that would be providing input to our nodes (implicit and explicit) + // and any nodes that take the output of our nodes (used to be If output) + // Map of NodeArg name to pair of Node* and input index in the destination node + using NodeAndIndex = std::pair, int>; + using ArgNameToNodeMap = InlinedHashMap; + ArgNameToNodeMap input_args; + // Map of NodeArg name to pair of Node* and output index in the source node. + ArgNameToNodeMap output_args; + + auto map_defs = [](Node& node, ArgNameToNodeMap& map, bool input) { + const auto defs = (input) ? node.InputDefs() : node.OutputDefs(); + map.reserve(map.size() + defs.size()); + int arg_pos = -1; + for (auto* node_arg : defs) { + ++arg_pos; + if (node_arg->Exists()) { + map.emplace(node_arg->Name(), std::make_pair(&node, arg_pos)); + } + } + }; + + const bool is_this_main_graph = (parent_graph_ == nullptr); + // Map the inputs and outputs of the If node to the nodes in the graph to inline. + if (!is_this_main_graph) { + for (auto& node : Nodes()) { + if (node.Index() == if_node.Index()) { + continue; + } + map_defs(node, input_args, true); + map_defs(node, output_args, false); + } + } + + auto* non_existing_arg = &GetOrCreateNodeArg(std::string(), nullptr); + // We want to make sure we get nodes in topological order + // because Constant folding may cause the nodes appear in + // a different order. + InlinedVector new_nodes; + GraphViewer graph(graph_to_inline); + for (const auto node_idx : graph.GetNodesInTopologicalOrder()) { + // GraphViewer filters out nullptrs + auto* node = graph_to_inline.GetNode(node_idx); + assert(node->OpType() != kConstant); + + // Inputs + // Chop off trailing non-existing defs, but preserve non-existing in the middle + auto& input_defs = node->MutableInputDefs(); + auto last_existing = std::find_if(input_defs.rbegin(), input_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + input_defs.resize(std::distance(input_defs.begin(), last_existing.base())); + + InlinedVector new_input_defs; + for (auto* input_def : node->InputDefs()) { + if (input_def->Exists()) { + // Check if this is one of the implicit graph inputs + // then re-assign the def to the outer scope value. + const auto& input_name = input_def->Name(); + auto outer_hit = outer_scope_values.find(input_name); + if (outer_hit != outer_scope_values.cend()) { + // get/create local definition + NodeArg* outer_arg = outer_hit->second; + auto& this_scope_arg = GetOrCreateNodeArg(outer_arg->Name(), input_def->TypeAsProto()); + new_input_defs.push_back(&this_scope_arg); + } else { + auto hit = name_to_nodearg.find(input_name); + if (hit != name_to_nodearg.cend()) { + // This is other node output in the dest graph, + // constant node or initializer that was renamed. + new_input_defs.push_back(hit->second); + } else { + ORT_THROW("Node's: ", node->Name(), " input: ", input_name, + " is not If node's input or previous node output in this subgraph"); + } + } + } else { + new_input_defs.push_back(non_existing_arg); + } + } + + // Outputs + // Chop off trailing non-existing defs + auto& output_defs = node->MutableOutputDefs(); + last_existing = std::find_if(output_defs.rbegin(), output_defs.rend(), + [](const NodeArg* node_arg) { return node_arg->Exists(); }); + output_defs.resize(std::distance(output_defs.begin(), last_existing.base())); + + InlinedVector new_output_defs; + for (auto* output_def : node->OutputDefs()) { + if (output_def->Exists()) { + const auto& output_name = output_def->Name(); + auto hit = name_to_nodearg.find(output_name); + if (hit != name_to_nodearg.cend()) { + // This is one of the If node outputs, simply reassign the def. + // If node defs are already in the destination graph + new_output_defs.push_back(hit->second); + } else { + // We generate an output to downstream nodes. + auto new_name = GenerateNodeArgName(make_unique(output_name)); + NodeArg& new_arg = GetOrCreateNodeArg(new_name, output_def->TypeAsProto()); + new_output_defs.push_back(&new_arg); + ORT_IGNORE_RETURN_VALUE(name_to_nodearg.emplace(output_name, &new_arg)); + } + } else { + new_output_defs.push_back(non_existing_arg); + } + } + + const auto new_node_name = GenerateNodeName(make_unique(node->OpType())); + Node& new_node = AddNode(new_node_name, node->OpType(), node->Description(), + new_input_defs, + new_output_defs, + nullptr, + node->Domain()); + + new_node.SetSinceVersion(node->SinceVersion()); + new_node.op_ = node->op_; + + if (!is_this_main_graph) { + map_defs(new_node, input_args, true); + map_defs(new_node, output_args, false); + new_nodes.push_back(&new_node); + } + + if (node->ContainsSubgraph()) { + auto& subgraphs = node->MutableSubgraphs(); + + // Check if any of this node implicit inputs of this graph is in the renaming map + // that would mean they come from the destination graph, not from the parent + // of the destination graph. + int renames_subgraph_names = 0; + auto& implicit_defs = node->MutableImplicitInputDefs(); + for (auto& input_def : implicit_defs) { + auto hit = name_to_nodearg.find(input_def->Name()); + if (hit != name_to_nodearg.cend()) { + input_def = hit->second; + ++renames_subgraph_names; + } + } + + for (auto& subgraph : subgraphs) { + if (renames_subgraph_names > 0) { + // We need to rename the subgraph node names + // because they may refer to the implicit inputs + // that were renamed. + ReassignSubgraphDependentNodeArgs(name_to_nodearg, *subgraph); + } + subgraph->parent_node_ = &new_node; + subgraph->parent_graph_ = this; + } + + new_node.MutableSubgraphs() = std::move(subgraphs); + new_node.GetMutableMapOfAttributeNameToSubgraph() = std::move(node->GetMutableMapOfAttributeNameToSubgraph()); + new_node.MutableImplicitInputDefs() = std::move(implicit_defs); + } + + new_node.GetMutableAttributes() = std::move(node->GetMutableAttributes()); + } + + // Let's rebuild local connections, so next time a GraphViewer is able to perform topological sort. + // We only need to do so if this graph is not the main graph, because the main graph is going to resolve + // and it is not possible to inline the same nodes again. + if (!is_this_main_graph) { + for (auto* node : new_nodes) { + int arg_pos = -1; + for (auto* input_def : node->InputDefs()) { + ++arg_pos; + auto hit = output_args.find(input_def->Name()); + if (hit != output_args.cend()) { + // The input to this node is an output from a previous node in this graph. + // Create relationship between this node (node), and the node providing the output (output_node). + const auto& [producer, src_idx] = hit->second; + AddEdge(producer->Index(), node->Index(), src_idx, arg_pos); + } + } + + // Check if any of the outputs for inlined nodes are inputs to other nodes in the graph. + // (outputs of If node) + arg_pos = -1; + for (auto& output_def : node->OutputDefs()) { + ++arg_pos; + auto hit = input_args.find(output_def->Name()); + if (hit != input_args.cend()) { + // The output of this node is an input to another node in this graph. + // Create relationship between this node (node), and the node using the input (input_node). + const auto& [consumer, dst_idx] = hit->second; + AddEdge(node->Index(), consumer->Index(), arg_pos, dst_idx); + } + } + } + } + + LOGS(logger, INFO) << "Constant folded (inlined) " << (condition_value ? then_branch : else_branch) + << " for If node: " << if_node.Name(); + + return Status::OK(); +} + Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline) { auto to_node_arg = [this](const std::string& name) { return &this->GetOrCreateNodeArg(name, nullptr); diff --git a/onnxruntime/core/mlas/.clang-format b/onnxruntime/core/mlas/.clang-format index 4a89ef98cf049..16ad8bd8a7234 100644 --- a/onnxruntime/core/mlas/.clang-format +++ b/onnxruntime/core/mlas/.clang-format @@ -2,10 +2,12 @@ BasedOnStyle: Google IndentWidth: 4 -ColumnLimit: 100 +# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained. +# Developers are responsible for adhering to the 120 character maximum. +ColumnLimit: 0 +AlignAfterOpenBracket: BlockIndent AlwaysBreakAfterReturnType: TopLevel AlwaysBreakTemplateDeclarations: Yes BinPackParameters: false BreakBeforeBraces: Linux ... - diff --git a/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h new file mode 100644 index 0000000000000..7ea29eb091318 --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h @@ -0,0 +1,33 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_gemm_postprocessor.h + +Abstract: + + This module contains a base class for custom postprocessing following a + GEMM. + +--*/ + +#pragma once + +template +class MLAS_GEMM_POSTPROCESSOR +{ + public: + virtual void Process(T* C, /**< the address of matrix to process */ + size_t RangeStartM, /**< the start row index of matrix */ + size_t RangeStartN, /**< the start col index of matrix */ + size_t RangeCountM, /**< the element count per row to process */ + size_t RangeCountN, /**< the element count per col to process */ + size_t ldc /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_GEMM_POSTPROCESSOR() {} +}; diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 7c7b729117e4a..316344ad8c214 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -21,6 +21,7 @@ Module Name: #pragma once #include "mlas.h" +#include "mlas_gemm_postprocessor.h" #include #include @@ -95,22 +96,6 @@ MlasQ4GemmUnPackB( ); -template -class MLAS_GEMM_POSTPROCESSOR -{ - public: - virtual void Process(T*, /**< the address of matrix to process */ - size_t, /**< the start row index of matrix */ - size_t, /**< the start col index of matrix */ - size_t, /**< the element count per row to process */ - size_t, /**< the element count per col to process */ - size_t /**< the leading dimension of matrix */ - ) const = 0; - - virtual ~MLAS_GEMM_POSTPROCESSOR() {} -}; - - /** * @brief Data parameters for Q4 GEMM routine * C = A * B + Bias @@ -241,7 +226,7 @@ MlasQ8Q4GemmBatch( * matrix shape [rows, columns], compute the shape of the * quantization parameter matrix [meta_rows, meta_cols] */ -template +template void MlasBlockwiseQuantMetaShape( int block_size, @@ -259,6 +244,7 @@ MlasBlockwiseQuantMetaShape( * is in column major layout, with bits packed on the column. * * @tparam T + * @tparam qbits * @param block_size * @param columnwise * @param rows @@ -266,7 +252,7 @@ MlasBlockwiseQuantMetaShape( * @param q_rows * @param q_cols */ -template +template void MlasBlockwiseQuantizedShape( int block_size, @@ -277,6 +263,32 @@ MlasBlockwiseQuantizedShape( int& q_cols ); +/** + * @brief Compute the sizes of the quantized data and quantization parameter buffers. + * + * @param qbits The bit width of each quantized value. + * @param block_size The number of quantized values in a block. + * @param columnwise Whether a block contains values from a matrix column (true) or row (false). + * @param rows Number of matrix rows. + * @param columns Number of matrix columns. + * @param[out] q_data_size_in_bytes The size in bytes of the quantized data. + * @param[out] q_scale_num_elements The size in elements of the scale quantization parameters. + * @param[out] q_zero_point_size_in_bytes The size in bytes of the zero point quantization parameters. Optional. + * + * If the qbits or block_size values are unsupported the output sizes will be zero. + */ +void MLASCALL +MlasBlockwiseQuantizedBufferSizes( + int qbits, + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + /** * @brief Blockwise 4 bits quantization, resulting elements and quantization diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h new file mode 100644 index 0000000000000..9620dd42d1da9 --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -0,0 +1,79 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_qnbit.h + +Abstract: + + This module contains the public data structures and procedure prototypes + for blocked n-bit quantized GEMM. + + N-bit block quantization is used to compress weight tensors of large + language models. + +--*/ + +#pragma once + +#include "mlas.h" +#include "mlas_gemm_postprocessor.h" + +/** + * @brief Data parameters for float/n-bit quantized int GEMM routine. + */ +struct MLAS_SQNBIT_GEMM_DATA_PARAMS { + const float* A = nullptr; ///< address of A (float32 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) + const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + bool IsBPacked = false; ///< whether B values are packed in an optimized format for the computation + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C + + ///< optional post processing to apply to result matrix + MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; +}; + +/** + * @brief Batched GEMM: C = A * B + Bias + * A must be a float32 matrix + * B must be a quantized and packed n-bit int matrix + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool optional thread pool to use + */ +void MLASCALL +MlasSQNBitGemmBatch( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool = nullptr +); + +/** + * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + */ +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t BlkBitWidth, + size_t BlkLen +); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index e0c2772cbb719..6c859e4e4f44b 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -890,14 +890,30 @@ extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchDot; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot; +// +// Quantized 8-bit integer/quantized 4-bit integer matrix/matrix multiply dispatch structure. +// + struct MLAS_Q8Q4GEMM_DISPATCH; extern const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni; +// +// Float/quantized 4-bit integer matrix/matrix multiply dispatch structure. +// + struct MLAS_FPQ4GEMM_DISPATCH; extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; +// +// Float/quantized n-bit integer matrix/matrix multiply dispatch structure. +// + +struct MLAS_SQNBIT_GEMM_DISPATCH; + +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; + // // Quantized depthwise convolution kernels. // @@ -1029,6 +1045,8 @@ struct MLAS_PLATFORM { const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; + + const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 39586282e00ad..fec56c6ee063f 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -460,6 +460,7 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index fbd1030de8ab7..48d975a7fd26d 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -422,6 +422,24 @@ struct BlockwiseQuantizer { q_cols = meta_cols * QuantBlk::kColumn; } + static MLAS_FORCEINLINE void quantizedBufferSizes( + int rows, int columns, size_t& data_bytes, size_t& scale_num_elements, size_t* zero_point_bytes + ) + { + int meta_rows, meta_cols; + quantizeMetaShape(rows, columns, meta_rows, meta_cols); + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + data_bytes = q_rows * q_cols; + scale_num_elements = meta_rows * meta_cols; + + if (zero_point_bytes) { + // this works for qbits == 4 but may need to be updated for other qbits values + *zero_point_bytes = ((meta_rows * qbits + 7) / 8) * meta_cols; + } + } + /** * @brief Quantized a Matrix shape [rows, columns], resulting quantized * and packed data are stored in column major (transposed) @@ -621,7 +639,7 @@ struct BlockwiseQuantizer { }; -template +template void MlasBlockwiseQuantMetaShape( int block_size, @@ -635,47 +653,47 @@ MlasBlockwiseQuantMetaShape( switch (block_size) { case 16: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } break; } case 32: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape( + BlockwiseQuantizer::quantizeMetaShape( rows, columns, meta_rows, meta_cols); } break; } case 64: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } break; } case 128: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } break; } case 256: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } break; @@ -689,7 +707,7 @@ MlasBlockwiseQuantMetaShape( -template +template void MlasBlockwiseQuantizedShape( int block_size, @@ -703,42 +721,42 @@ MlasBlockwiseQuantizedShape( switch (block_size) { case 16: { if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } break; } case 32: { if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape( + BlockwiseQuantizer::quantizedShape( rows, columns, q_rows, q_cols); } break; } case 64: { if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } break; } case 128: { if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } break; } case 256: { if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } break; } @@ -752,7 +770,7 @@ MlasBlockwiseQuantizedShape( template void -MlasBlockwiseQuantMetaShape( +MlasBlockwiseQuantMetaShape( int block_size, bool columnwise, int rows, @@ -763,7 +781,7 @@ MlasBlockwiseQuantMetaShape( template void -MlasBlockwiseQuantizedShape( +MlasBlockwiseQuantizedShape( int block_size, bool columnwise, int rows, @@ -773,6 +791,93 @@ MlasBlockwiseQuantizedShape( ); +void MLASCALL +MlasBlockwiseQuantizedBufferSizes( + int qbits, + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +) +{ + q_data_size_in_bytes = q_scale_num_elements = 0; + if (q_zero_point_size_in_bytes) { + *q_zero_point_size_in_bytes = 0; + } + + if (qbits == 4) { + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } + } +} + + template void MlasQuantizeBlockwise( diff --git a/onnxruntime/core/mlas/lib/q4gemm.h b/onnxruntime/core/mlas/lib/q4gemm.h index 1562f9c0b4236..b1b51dd53c4fc 100644 --- a/onnxruntime/core/mlas/lib/q4gemm.h +++ b/onnxruntime/core/mlas/lib/q4gemm.h @@ -90,7 +90,7 @@ MlasQ4GemmOperation( if (DataParams->OutputProcessor != nullptr) { DataParams->OutputProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, RowsHandled, CountN, ldc); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp new file mode 100644 index 0000000000000..f964b1affec31 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -0,0 +1,144 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch. +--*/ + +#include "sqnbitgemm.h" + +namespace +{ + +// Get quantization variant based on `BlkBitWidth` and `BlkLen`. +// Return -1 if the input values are unsupported. +int32_t +GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen) +{ + int32_t type = -1; + if (BlkBitWidth == 4 && BlkLen == 16) { + type = QuantVariant_BitWidth4_BlockSize16; + } else if (BlkBitWidth == 4 && BlkLen == 32) { + type = QuantVariant_BitWidth4_BlockSize32; + } else if (BlkBitWidth == 4 && BlkLen == 64) { + type = QuantVariant_BitWidth4_BlockSize64; + } else if (BlkBitWidth == 4 && BlkLen == 128) { + type = QuantVariant_BitWidth4_BlockSize128; + } else if (BlkBitWidth == 4 && BlkLen == 256) { + type = QuantVariant_BitWidth4_BlockSize256; + } + + return type; +} + +} // namespace + +void MLASCALL +MlasSQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool +) +{ + const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); + MLAS_SQNBIT_GEMM_OPERATION* const Operation = GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant]; + + if (ThreadPool == nullptr) { + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + auto Data = &DataParams[gemm_i]; + Operation(K, Data, 0, M, 0, N); + } + return; + } + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K) * double(BatchN); + + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8; + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + constexpr size_t StrideM = 128; + + size_t nc = N; + if (ThreadsPerGemm > 1) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min( + nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * + MLAS_QGEMM_STRIDEN_THREAD_ALIGN + ); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + auto Data = &DataParams[gemm_i]; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} + +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t BlkBitWidth, + size_t BlkLen +) +{ + const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); + if (QuantVariant == -1) { + return false; + } + + if (GetMlasPlatform().SQNBitGemmDispatch == nullptr || + GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant] == nullptr) { + return false; + } + + return true; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h new file mode 100644 index 0000000000000..f8f7dcd43699f --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -0,0 +1,287 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm.h + +Abstract: + + This module includes: + + - Declaration of the set of template functions used to implement a kernel + for a matrix/matrix multiplication, A*B, where A is a float matrix and B is + a n-bit quantized integer matrix (QNBitGemm). + + - A shared kernel driver function template, MlasSQNBitGemmOperation. + + - Kernel dispatch structure. + + The B matrix is block quantized, which means that its values are grouped + into blocks which each have one scale and optional zero point. Each + quantized value in B is n-bits wide. + +--*/ + +#pragma once + +#include "mlas_qnbit.h" +#include "mlasi.h" + +// +// Kernel implementation template declarations +// + +/** + * @brief Multiply float matrix A with quantized n-bit integer matrix B. + * B is block quantized and column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @tparam BlkBitWidth Bit width of each value in a block. + * @tparam BlkLen Number of values in a block. + * @tparam KernelType Hardware-specific kernel type. + * + * @param A Supplies the A matrix. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ +template +MLAS_FORCEINLINE void +MlasSQNBitGemmM1Kernel( + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +); + +/** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is block quantized and column major. + * This is equivalent to dequantizing B and then running + * MlasSgemmCopyPackB. + * + * @tparam BlkBitWidth Bit width of each value in a block. + * @tparam BlkLen Number of values in a block. + * @tparam KernelType Hardware-specific kernel type. + * + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ +template +MLAS_FORCEINLINE void +MlasQNBitBlkDequantBForSgemm( + float* FpData, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB +); + +// +// MlasQNBitGemmOperation and helpers +// + +constexpr MLAS_FORCEINLINE size_t +MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) +{ + return BlkLen * BlkBitWidth / 8; +} + +template +constexpr MLAS_FORCEINLINE size_t +MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) +{ + if constexpr (BlkBitWidth <= 4) { + return MlasDivRoundup(BlkCount, 2); // 2 blocks per byte + } else { + return BlkCount; + } +} + +MLAS_FORCEINLINE void +MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) +{ + for (size_t m = 0; m < CountM; m++) { + const float* bias = Bias; + float* sum = C; + for (size_t n = 0; n < CountN; n += 4) { + if (CountN - n < 4) { + for (size_t nn = n; nn < CountN; nn++) { + *sum += *bias; + sum++; + bias++; + } + break; + } + + MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); + acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); + MlasStoreFloat32x4(sum, acc_x); + bias += 4; + sum += 4; + } + C += ldc; + } +} + +template +MLAS_FORCEINLINE void MLASCALL +MlasSQNBitGemmOperation( + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const float* A = DataParams->A + RangeStartM * lda; + + const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const uint8_t* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const float* a_row = A; + const uint8_t* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const uint8_t* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + MlasSQNBitGemmM1Kernel( + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; + } + + constexpr size_t StrideN = 32; + size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + // + // Step through each slice of matrix B along the N dimension. + // + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, StrideN); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* a_row = A; + const uint8_t* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const uint8_t* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + MlasQNBitBlkDequantBForSgemm( + dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks + ); + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true + ); +#else + auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); +#endif + + if (bias) { + MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); + } + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + RowsHandled, CountN, ldc + ); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } +} + +// +// Kernel dispatch structure. +// + +typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)( + size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + size_t RangeStartM, + size_t RangeCountM, + size_t RangeStartN, + size_t RangeCountN +); + +enum QuantVariant { + QuantVariant_BitWidth4_BlockSize16, + QuantVariant_BitWidth4_BlockSize32, + QuantVariant_BitWidth4_BlockSize64, + QuantVariant_BitWidth4_BlockSize128, + QuantVariant_BitWidth4_BlockSize256, + QuantVariantCount, // Keep this element last and ensure that its value is the number of other QuantVariant values. + // Its value is used as an array size. +}; + +struct MLAS_SQNBIT_GEMM_DISPATCH { + MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = { + // Initialized to nullptrs. Overwrite in hardware-specific kernel implementation. + }; +}; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..63afe57dd9137 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -0,0 +1,489 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON. + +--*/ + +#include "sqnbitgemm.h" + +#include + +#include +#include +#include + +// +// Hardware-specific kernel type. +// +struct MLAS_SQNBIT_GEMM_KERNEL_NEON { +}; + +namespace +{ + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +MLAS_FORCEINLINE float32x4_t +FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) +{ + // aN: aN_0 aN_1 aN_2 aN_3 + + float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 + float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 + float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 + float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 + + // a0_0 a1_0 a2_0 a3_0 + a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_1 a1_1 a2_1 a3_1 + a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_2 a1_2 a3_2 a3_2 + a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); + // a0_3 a1_3 a2_3 a3_3 + a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); + + return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); +} + +template +MLAS_FORCEINLINE void +LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4]) +{ + static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); + + assert(count <= Capacity); + + size_t vi = 0; // vector index + + // handle 4 values at a time + while (count > 3) { + dst[vi] = vld1q_f32(src); + + vi += 1; + src += 4; + count -= 4; + } + + // handle remaining values + if (count > 0) { + dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); + + if (count > 1) { + dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); + + if (count > 2) { + dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); + } + } + } +} + +template +MLAS_FORCEINLINE void +ComputeDotProducts( + const float* ARowPtr, + const uint8_t* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const uint8_t* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + // Manual conversion to float takes place in two steps: + // 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. + // This target float range is convenient because the 4-bit source values can be placed directly into the + // target float bits. + // 2. Subtract the conversion offset of 16 from the float result. + + // The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. + constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; + // sign|exponent|partial mantissa + // +|131: 2^4|~~~~ <- 4 bits go here + + const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); + + float32x4_t acc[NCols]{}; + + const uint8_t* QuantBData = QuantBDataColPtr; + const float* QuantBScale = QuantBScaleColPtr; + size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + float scale[NCols]; + UnrolledLoop( + [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } + ); + + float offset[NCols]; // Includes zero point and float conversion offset of 16. + if (QuantBZeroPointColPtr != nullptr) { + UnrolledLoop([&](size_t i) { + const uint8_t zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const uint8_t zp = ((QuantBZeroPointIdx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & 0x0F); + offset[i] = 16.0f + zp; + }); + } else { + UnrolledLoop([&](size_t i) { + constexpr float zp = 8.0f; + offset[i] = 16.0f + zp; + }); + } + + constexpr size_t SubBlkLen = 16; // number of block elements to process in one iteration + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { + // load A row vector elements + + // load `SubBlkLen` elements from A, padded with 0's if there aren't enough + const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); + float32x4_t av[4]{}; + LoadData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; + bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset); + }); + + uint8x8_t bv_u8_unzipped[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); + bv_u8[i][1] = vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); + }); + + // dequantize B + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset (16) and zero point + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(scale[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // c[m,n] += a[m,k] * b[k,n] + UnrolledLoop<4>([&](size_t j) { + UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); + }); + } + + // increment pointers to next block + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; + QuantBZeroPointIdx += 1; + } + + if constexpr (NCols == 4) { + float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (BiasPtr != nullptr) { + sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sum); + } else { + for (size_t i = 0; i < NCols; ++i) { + SumPtr[i] = vaddvq_f32(acc[i]); + if (BiasPtr != nullptr) { + SumPtr[i] += BiasPtr[i]; + } + } + } +} + +} // namespace + +// +// MlasSQNBitGemmKernel and helpers. +// + +template +MLAS_FORCEINLINE void +MlasSQNBitGemmM1KernelNeon( + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t NCols = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + ComputeDotProducts( + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + + nblk -= NCols; + } + + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts( + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +#define SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(BlkBitWidth, BlkLen) \ + template <> \ + MLAS_FORCEINLINE void \ + MlasSQNBitGemmM1Kernel( \ + const float* A, \ + const uint8_t* QuantBData, \ + const float* QuantBScale, \ + const uint8_t* QuantBZeroPoint, \ + float* C, \ + size_t CountN, \ + size_t CountK, \ + size_t BlockStrideQuantB, \ + const float* Bias \ + ) \ + { \ + return MlasSQNBitGemmM1KernelNeon( \ + A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, \ + BlockStrideQuantB, Bias \ + ); \ + } + +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 16) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 32) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 64) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 128) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 256) + +#undef SPECIALIZE_SQNBIT_GEMM_M1_KERNEL + +// +// MlasQNBitBlkDequantBForSgemm and helpers. +// + +template +MLAS_FORCEINLINE void +MlasQNBitBlkDequantBForSgemmNeon( + float* FpData, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB +) +{ + auto impl0_reference = [&]() { + static_assert(BlkBitWidth == 4); + + float* Dst = FpData; + + const uint8_t* QuantBDataCol = QuantBData; + const float* QuantBScaleCol = QuantBScale; + const uint8_t* QuantBZeroPointCol = QuantBZeroPoint; + + for (size_t n = 0; n < CountN; n += 16) { + const size_t nnlen = std::min(CountN - n, size_t{16}); + + for (size_t nn = 0; nn < nnlen; ++nn) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) { + const size_t kklen = std::min(CountK - k, BlkLen); + + const uint8_t* b_data = + QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const float b_s = QuantBScaleCol[k_blk_idx]; + const uint8_t b_z = + (QuantBZeroPointCol != nullptr) + ? ((k_blk_idx & 1) == 1) + ? QuantBZeroPointCol[k_blk_idx / 2] >> 4 + : QuantBZeroPointCol[k_blk_idx / 2] & 0x0F + : 8; + + for (size_t kk = 0; kk < kklen; ++kk) { + const uint8_t b_packed = b_data[kk / 2]; + const uint8_t b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & 0x0F; + const float b_value = (b_byte - b_z) * b_s; + + Dst[(k + kk) * 16 + nn] = b_value; + } + } + + QuantBDataCol += BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScaleCol += BlockStrideQuantB; + if (QuantBZeroPointCol != nullptr) { + QuantBZeroPointCol += MlasQNBitZeroPointsForBlksSizeInBytes(BlockStrideQuantB); + } + } + + // zero out any remaining columns + + if (nnlen < 16) { + for (size_t k = 0; k < CountK; ++k) { + std::fill_n(Dst + (k * 16) + nnlen, 16 - nnlen, 0.0f); + } + } + + Dst += CountK * 16; + } + }; + + impl0_reference(); +} + +#define SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(BlkBitWidth, BlkLen) \ + template <> \ + MLAS_FORCEINLINE void \ + MlasQNBitBlkDequantBForSgemm( \ + float* FpData, \ + const uint8_t* QuantBData, \ + const float* QuantBScale, \ + const uint8_t* QuantBZeroPoint, \ + size_t CountN, \ + size_t CountK, \ + size_t BlockStrideQuantB \ + ) \ + { \ + MlasQNBitBlkDequantBForSgemmNeon( \ + FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB \ + ); \ + } + +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 16) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 32) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 64) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 128) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256) + +#undef SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM + +// +// Kernel dispatch structure definition. +// + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + d.Operations[QuantVariant_BitWidth4_BlockSize16] = MlasSQNBitGemmOperation<4, 16, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize32] = MlasSQNBitGemmOperation<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize64] = MlasSQNBitGemmOperation<4, 64, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize128] = MlasSQNBitGemmOperation<4, 128, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize256] = MlasSQNBitGemmOperation<4, 256, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + return d; +}(); diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc index 094ea1e24dd92..9c98ed6d3e114 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc @@ -338,8 +338,8 @@ std::optional IsSupportedGather(Graph& graph, Node& node, auto axis = static_cast(node.GetAttributes().at("axis").i()); axis = axis < 0 ? axis + data_rank : axis; size_t dim_size = static_cast(indices_shape->dim_size()); - bool is_single_value_1d_tensor = dim_size != 0 && (dim_size == 1 && utils::HasDimValue(indices_shape->dim(0)) && - indices_shape->dim(0).dim_value() == 1); + bool is_single_value_1d_tensor = dim_size == 1 && utils::HasDimValue(indices_shape->dim(0)) && + indices_shape->dim(0).dim_value() == 1; if (dim_size != 0 && !is_single_value_1d_tensor) { if (dim_size == 1 && utils::HasDimValue(data_shape->dim(axis)) && data_shape->dim(axis).dim_value() > indices_shape->dim(0).dim_value()) { diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc index 716b027068ba1..23f7c45fba4ba 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc @@ -3,6 +3,7 @@ #ifdef ENABLE_TRAINING +#include #include "core/optimizer/utils.h" #include "core/optimizer/compute_optimizer/upstream_reshape_actors.h" @@ -282,6 +283,23 @@ bool LayerNormalizationReshapeActor::PreCheck( return propagate_input_indices.size() > 0; } +bool LayerNormalizationReshapeActor::PostProcess( + Graph& /* graph */, Node& current_node, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + // When Reshape(from 3D to 2D, with the first two dimensions be merged) upstream a LayerNormalization, + // The axis attribute of LayerNormalization should be decreased by 1 if it is greater than 1. + if (axis > 1) { + auto new_axis = axis - 1; + auto& attributes = current_node.GetMutableAttributes(); + attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis)); + } + return true; +} + template class SimplePointwiseReshapeActor; template class SimplePointwiseReshapeActor; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h index 05bcbabe9ba4c..de50a56fd8781 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h @@ -111,13 +111,11 @@ class UpStreamReshapeOperatorActorBase : public UpStreamOperatorActorBase { * So far, we don't have requirements to override PostProcess function. */ - bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, - const logging::Logger& /* logger */, - std::vector& /* propagate_input_indices */, - const std::unordered_map>& /* all_input_cmp_rets */, - const std::unordered_map& /* new_reshape_infos */) { - return true; - } + virtual bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) = 0; }; // The inputs are broad-cast-able. The outputs should have the same shape (fully broadcasted shape) @@ -133,6 +131,14 @@ class SimplePointwiseReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override { + return true; + } }; class MatMulReshapeActor : public UpStreamReshapeOperatorActorBase { @@ -145,6 +151,14 @@ class MatMulReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override { + return true; + } }; class LayerNormalizationReshapeActor : public UpStreamReshapeOperatorActorBase { @@ -157,6 +171,12 @@ class LayerNormalizationReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& current_node, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override; }; /** diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index f46273f2680a9..e3a2f2d74c0d4 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -4,6 +4,7 @@ #include #include "core/optimizer/constant_folding.h" +#include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/optimizer_execution_frame.h" @@ -90,6 +91,45 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) { return is_concrete_shape; // convert to constant if this is true } +// This function inlines the appropriate subgraph. It does not literally fold it. +static Status ConstantFoldIfNode(Graph& graph, Node& if_node, const logging::Logger& logger, bool& folded) { + folded = false; + // First, find out which subgraph to inline + // We need to fetch the constant argument. + assert(if_node.InputDefs().size() == 1); + const auto* condition_def = if_node.InputDefs()[0]; + + // We need to check if the condition is a constant. + constexpr bool check_outer_scope_true = true; + const ONNX_NAMESPACE::TensorProto* initializer = + graph.GetConstantInitializer(condition_def->Name(), check_outer_scope_true); + if (initializer == nullptr) { + return Status::OK(); + } + + // This is a boolean initializer with a single element. + Initializer condition{*initializer}; + ORT_RETURN_IF_NOT(condition.size() == 1, "If node condition initializer: `", condition_def->Name(), + "' is expected to have a single boolean element"); + + const bool condition_value = *condition.data(); + + auto status = graph.InlineIfSubgraph(condition_value, if_node, logger); + + if (!status.IsOK()) { + LOGS(logger, WARNING) << "Unable to constant fold. InlineIfSubgraph failed " + << " node '" << if_node.Name() << "': " + << status.ErrorMessage(); + return status; + } + + graph_utils::RemoveNodeOutputEdges(graph, if_node); + graph.RemoveNode(if_node.Index()); + + folded = true; + return status; +} + Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { bool have_updated_nodes = false; GraphViewer graph_viewer(graph); @@ -118,7 +158,20 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, } bool converted_to_constant = false; - if (node->OpType().compare("Shape") == 0) { + if (node->OpType().compare("If") == 0) { + // This process constant folds the If node only, + // but inlines the nodes of the corresponding branch graph. + // It does not convert the node to a constant in a common sense. + // We call it constant folding because the `If` node constant condition + // may enable us to inline the corresponding branch graph. + bool folded = false; + ORT_RETURN_IF_ERROR(ConstantFoldIfNode(graph, *node, logger, folded)); + if (folded) { + // Node removal is done within ConstantFoldIfNode() + modified = true; + have_updated_nodes = true; + } + } else if (node->OpType().compare("Shape") == 0) { converted_to_constant = ConstantFoldShapeNode(graph, *node); } else { InitializedTensorSet constant_inputs; diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index b994028cbca13..4903bc1d6b961 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -9,7 +9,8 @@ namespace onnxruntime { -bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const { +bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { return false; @@ -53,6 +54,22 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + InlinedVector node_args; + for (auto node_arg : graph.GetInputs()) { + if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) { + node_args.push_back(node_arg); + } + } + + for (auto entry : graph.GetAllInitializedTensors()) { + if (graph.GetConsumerNodes(entry.first).size() > 1) { + auto node_arg = graph.GetNodeArg(entry.first); + if (node_arg) { + node_args.push_back(node_arg); + } + } + } + for (auto node_index : node_topology_list) { auto* p_node = graph.GetNode(node_index); if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion @@ -73,7 +90,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le size_t output_count = node.GetOutputEdgesCount(); if (output_count <= 1) continue; - auto shape = node.MutableOutputDefs()[0]->Shape(); + node_args.push_back(node.OutputDefs()[0]); + } + + for (const NodeArg* node_arg : node_args) { + auto shape = node_arg->Shape(); if (!shape) continue; int64_t rank = static_cast(shape->dim_size()); @@ -81,11 +102,14 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le bool first_edge = true; int64_t split_axis = 0; int64_t indices_n_dims = -1; - InlinedVector gather_outputs(output_count, nullptr); + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); + InlinedVector gather_outputs(consumer_count, nullptr); InlinedVector> nodes_to_fuse; - for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) { + for (auto consumer : consumers) { int64_t index, axis, dims; - if (!IsSupportedGather(graph, *it, index, axis, dims)) { + if (!consumer || consumer->InputDefs()[0] != node_arg || + !IsSupportedGather(graph, *consumer, index, axis, dims)) { can_fuse = false; break; } @@ -99,7 +123,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (axis < 0) axis += rank; if (first_edge) { auto dim = shape->dim(static_cast(axis)); - if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(output_count)) { + if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(consumer_count)) { can_fuse = false; break; } @@ -109,12 +133,12 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le can_fuse = false; break; } - if (index < 0) index += static_cast(output_count); - if (index < 0 || index >= static_cast(output_count) || gather_outputs[static_cast(index)]) { + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count) || gather_outputs[static_cast(index)]) { can_fuse = false; break; } - Node& gather_node = *graph.GetNode(it->Index()); + Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.emplace_back(gather_node); gather_outputs[static_cast(index)] = gather_node.MutableOutputDefs()[0]; } @@ -122,8 +146,8 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (!can_fuse) continue; ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( - node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()); + const ONNX_NAMESPACE::TensorProto_DataType element_type = + static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); split_output_type.mutable_tensor_type()->set_elem_type(element_type); for (int64_t i = 0; i < rank; ++i) { if (i == split_axis) { @@ -136,16 +160,17 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le InlinedVector split_outputs; bool add_squeeze_node = indices_n_dims == 0; if (add_squeeze_node) { - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { split_outputs.emplace_back( &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type)); } } - Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", - {node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs); + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", + {graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs); split_node.AddAttribute("axis", split_axis); - split_node.SetExecutionProviderType(node.GetExecutionProviderType()); + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas. int onnx_opset_version = -1; @@ -155,16 +180,16 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (onnx_opset_version < 13) { if (add_squeeze_node) { - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]}); squeeze_node.AddAttribute("axes", std::vector{split_axis}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); } } } else { if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(output_count)); + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); } if (add_squeeze_node) { @@ -176,11 +201,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t)); NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); } } } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index f42766267b0f9..3d2a81ce7f8cd 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -87,12 +87,19 @@ std::vector WhereMoves() { MoveAll(q, ArgType::kOutput)}; return moves; } -QDQReplaceWithNew SplitReplacer() { +QDQReplaceWithNew SplitReplacer(bool has_split_as_input) { NTO::NodeLocation dq{NTO::NodeType::kInput, 0}; + NTO::NodeLocation target{NTO::NodeType::kTarget, 0}; NTO::NodeLocation q{NTO::NodeType::kOutput, 0}; - std::vector moves{ - MoveAndAppend(dq, ArgType::kInput, 0, ArgType::kInput), - MoveAll(q, ArgType::kOutput)}; + std::vector moves{MoveAndAppend(dq, ArgType::kInput, 0, ArgType::kInput)}; + + if (has_split_as_input) { + // Move the optional split input to the new node. + moves.push_back(MoveAndAppend(target, ArgType::kInput, 1, ArgType::kInput, true)); + } + + moves.push_back(MoveAll(q, ArgType::kOutput)); + return QDQReplaceWithNew(kOnnxDomain, "Split", std::move(moves)); } @@ -247,7 +254,12 @@ MatMulReplaceWithQLinear::MatMulReplaceWithQLinear() } Status SplitReplaceWithQuant::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { - return SplitReplacer().Run(graph, selected_nodes); + const auto& target_node = selected_nodes.Target(); + const auto& input_defs = target_node.InputDefs(); + + // The 'split' attribute became an optional input at opset 13. + bool has_split_as_input = target_node.SinceVersion() >= 13 && input_defs.size() == 2; + return SplitReplacer(has_split_as_input).Run(graph, selected_nodes); } Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 0e383c3031ca6..29178fe87f75c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -20,7 +20,7 @@ void SplitQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { const std::string action_name{"dropSplitQDQ"}; std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); + std::unique_ptr selector = std::make_unique(true /*req_equal_quant_params*/); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Split", {}}}, std::move(selector), diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 5015e48fdb7b8..15b501c667046 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -253,7 +253,39 @@ void InputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder builder.num_input_defs = 1; // set to 1 as the first input is variadic } -void OutputVariadicSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { +bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes, 1)) { + return false; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + const Node& dq_node = *dq_nodes.front(); + int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + // All Q outputs should have same data type and (optionally) equal quantization parameters as the input. + for (size_t q_idx = 0; q_idx < q_nodes.size(); q_idx++) { + const Node& q_node = *q_nodes[q_idx]; + + if (dt_input != q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()) { + return false; + } + + if (req_equal_quant_params_ && + !IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath())) { + return false; + } + } + + return true; +} + +void SplitSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { builder.num_output_defs = 1; // set to 1 as the first output is variadic } @@ -443,7 +475,6 @@ bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& gr } int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - int32_t dt_scale = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_bias = 0; bool has_bias = false; // bias is optional for LayerNorm @@ -453,9 +484,9 @@ bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& gr } int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - // Input, output, and scale need to be the same type. The bias is int32. + // Input, output, need to be the same type. The bias is int32. + // Scale can be different with input for a16w8 case return (dt_input == dt_output) && - (dt_input == dt_scale) && (has_bias ? dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32 : true); } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index be7f7e0288eda..d0d7fb2c2af17 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -115,6 +115,24 @@ class VariadicNodeGroupSelector : public NodeGroupSelector { bool allow_16bit_; }; +// DQ node -> Split -> multiple Q nodes with equal quantization types. +// Optionally, the selector can require all input and output quantization parameters to be +// equal and constant. +class SplitNodeGroupSelector : public NodeGroupSelector { + public: + explicit SplitNodeGroupSelector(bool req_equal_quant_params = false) + : req_equal_quant_params_(req_equal_quant_params) {} + + private: + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; + + bool req_equal_quant_params_; // If true, only selects a node group if the input and output + // quantization parameters are all equal/constant, which enables the + // optimizer to drop the Q/DQ ops if the group is assigned to the CPU EP. +}; + // DQ nodes for X, W and optionally B -> node -> Q class ConvNodeGroupSelector : public NodeGroupSelector { public: @@ -288,10 +306,11 @@ class InputVariadicSelector : public BaseSelector { void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; -// DQ -> node -> Variadic Q nodes -class OutputVariadicSelector : public BaseSelector { +// DQ -> Split -> variadic Q nodes +class SplitSelector : public BaseSelector { public: - OutputVariadicSelector() : BaseSelector(std::make_unique()) {} + SplitSelector(bool req_equal_quant_params = false) + : BaseSelector(std::make_unique(req_equal_quant_params)) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 1a4d3a0c18151..e2aa25897ee06 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -27,6 +27,9 @@ void Selectors::RegisterSelector(const OpVersionsAndSelector::OpVersionsMap& ops } /* static methods to return different operator's OpVersionMap */ + +// These are operators that do not change the data and therefore the input DQ and +// output Q have the same scale and zero_point. static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { return {{"Gather", {}}, {"Reshape", {}}, @@ -35,7 +38,6 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Transpose", {}}, {"MaxPool", {12}}, {"Resize", {}}, - {"Split", {}}, {"Squeeze", {}}, {"Unsqueeze", {}}, {"Tile", {}}}; @@ -97,6 +99,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetVariadicOpVersionsMap() { {"Max", {}}, {"Min", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetSplitOpVersionsMap() { + return {{"Split", {}}}; +} static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() { return {{"Conv", {}}}; } @@ -170,6 +175,13 @@ void RegisterVariadicSelectors(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterSplitSelector(Selectors& qdq_selectors) { + /* register selectors for Split op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetSplitOpVersionsMap(), + std::move(selector)); +} + void RegisterConvSelector(Selectors& qdq_selectors) { /* register selector for conv op */ std::unique_ptr selector = std::make_unique(); @@ -247,6 +259,7 @@ void SelectorManager::CreateSelectors() { RegisterUnarySelectors(qdq_selectors_); RegisterBinarySelectors(qdq_selectors_); RegisterVariadicSelectors(qdq_selectors_); + RegisterSplitSelector(qdq_selectors_); RegisterConvSelector(qdq_selectors_); RegisterConvTransposeSelector(qdq_selectors_); RegisterMatMulSelector(qdq_selectors_); diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 60e0b1c061a43..4a6743e9e5c52 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -169,6 +170,60 @@ Status CreateInputFeatureProvider(const std::unordered_map mlmultiarray_buffer_size) { + if (mlmultiarray_buffer == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "mlmultiarray_buffer has no data"); + } + + const size_t num_elements = array_info.count; + const auto onnx_data_type = tensor_info->data_type; + switch (onnx_data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + const auto output_data_byte_size = num_elements * sizeof(float); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, + "CoreML output buffer size and expected output size differ"); + memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + const auto output_data_byte_size = num_elements * sizeof(int32_t); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, + "CoreML output buffer size and expected output size differ"); + memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + break; + } + // For this case, since Coreml Spec only uses int32 for model output while onnx provides + // int64 for model output data type. We are doing a type casting (int32 -> int64) here + // when copying the model to ORT + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + ORT_RETURN_IF_NOT(array_info.dataType == MLMultiArrayDataTypeInt32, + "CoreML output data type is not MLMultiArrayDataTypeInt32"); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == num_elements * sizeof(int32_t), + "CoreML output buffer size and expected output size differ"); + const auto model_output_span = gsl::span{static_cast(mlmultiarray_buffer), num_elements}; + const auto output_span = gsl::span{static_cast(tensor_buffer), num_elements}; + std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), + [](int32_t v) { return static_cast(v); }); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Output data type is not supported, actual type: ", onnx_data_type); + } + return Status::OK(); +} } // namespace NS_ASSUME_NONNULL_BEGIN @@ -298,9 +353,9 @@ - (Status)predict:(const std::unordered_map&)inputs return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); } - auto* data = [output_value multiArrayValue]; + MLMultiArray* data = [output_value multiArrayValue]; - const auto coreml_static_output_shape = [&]() { + const auto coreml_static_output_shape = [data]() { InlinedVector result; result.reserve(data.shape.count); for (NSNumber* dim in data.shape) { @@ -324,41 +379,21 @@ - (Status)predict:(const std::unordered_map&)inputs ") do not match"); } - const void* model_output_buffer = data.dataPointer; - - if (model_output_buffer == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model_output_buffer has no data for ", output_name); - } - - const auto onnx_data_type = output_tensor_info.data_type; - switch (onnx_data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - const auto output_data_byte_size = num_elements * sizeof(float); - memcpy(output_buffer, model_output_buffer, output_data_byte_size); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - const auto output_data_byte_size = num_elements * sizeof(int32_t); - memcpy(output_buffer, model_output_buffer, output_data_byte_size); - break; - } - // For this case, since Coreml Spec only uses int32 for model output while onnx provides - // int64 for model output data type. We are doing a type casting (int32 -> int64) here - // when copying the model to ORT - case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - ORT_RETURN_IF_NOT(data.dataType == MLMultiArrayDataTypeInt32, - "CoreML output data type is not MLMultiArrayDataTypeInt32"); - - const auto model_output_span = gsl::span{static_cast(model_output_buffer), num_elements}; - const auto output_span = gsl::span{static_cast(output_buffer), num_elements}; - std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), - [](int32_t v) { return static_cast(v); }); - break; - } - default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Output data type is not supported, actual type: ", onnx_data_type); + ORT_RETURN_IF_NOT(IsArrayContiguous(data), + "Non-contiguous output MLMultiArray is not currently supported"); + __block Status copy_status; + const auto* tensor_info = &output_tensor_info; + // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions + if (@available(macOS 12.3, iOS 15.4, *)) { + [data getBytesWithHandler:^(const void* bytes, NSInteger size) { + copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, tensor_info, size); + }]; + } else { + // disable size check as old API does not return buffer length + copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, tensor_info, std::nullopt); } + if (!copy_status.IsOK()) + return copy_status; } } } diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index ce834e371fdef..3c83394fb0bf4 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -688,21 +688,23 @@ FastReduceKind OptimizeShapeForFastReduce(gsl::span input_shape, return FastReduceKind::kNone; } -void ValidateCommonFastReduce(const Tensor* axes_tensor) { - ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, - "An axes tensor must be a vector tensor."); -} - // template bool CommonFastReduceCopy(OpKernelContext* ctx, TensorShapeVector& input_axes, bool noop_with_empty_axes) { if (ctx->InputCount() == 2) { // second input holds the axes. + // the argument is optional const Tensor* axes_tensor = ctx->Input(1); - ValidateCommonFastReduce(axes_tensor); - auto nDims = static_cast(axes_tensor->Shape()[0]); - const auto* data = axes_tensor->Data(); - input_axes.insert(input_axes.begin(), data, data + nDims); + + if (axes_tensor != nullptr) { + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + + const auto data_span = axes_tensor->DataAsSpan(); + input_axes.assign(data_span.begin(), data_span.end()); + } else { + input_axes.clear(); + } + if (input_axes.empty() && noop_with_empty_axes) { const Tensor* input = ctx->Input(0); auto* output = ctx->Output(0, input->Shape()); diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc index 288ca8e97e34d..33f2938940e4d 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.cc +++ b/onnxruntime/core/providers/cuda/cuda_common.cc @@ -62,7 +62,8 @@ const char* CudaDataTypeToString(cudaDataType_t dt) { return "CUDA_R_16BF"; case CUDA_R_32F: return "CUDA_R_32F"; -#if (CUDA_VERSION >= 11080) +#if !defined(DISABLE_FLOAT8_TYPES) + // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8 case CUDA_R_8F_E4M3: return "CUDA_R_8F_E4M3"; case CUDA_R_8F_E5M2: @@ -101,7 +102,7 @@ cudaDataType_t ToCudaDataType(int32_t element_type) { return CUDA_R_16F; case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return CUDA_R_16BF; -#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) +#if !defined(DISABLE_FLOAT8_TYPES) case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: return CUDA_R_8F_E4M3; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 9cd4e721ccab8..707099bac3ce0 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -58,6 +58,8 @@ class ToCudaType { } }; +#if !defined(DISABLE_FLOAT8_TYPES) + template <> class ToCudaType { public: @@ -76,6 +78,8 @@ class ToCudaType { } }; +#endif + inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { int stride = 1; if (dims.empty() || p.size() < dims.size()) diff --git a/onnxruntime/core/providers/cuda/tensor/cast_op.cu b/onnxruntime/core/providers/cuda/tensor/cast_op.cu index 7542fb55757c6..f2c2e6d7458f9 100644 --- a/onnxruntime/core/providers/cuda/tensor/cast_op.cu +++ b/onnxruntime/core/providers/cuda/tensor/cast_op.cu @@ -141,7 +141,7 @@ struct CastSat { #endif -#endif +#endif // DISABLE_FLOAT8_TYPES template __global__ void CastKernelStd(const InT* input, OutT* output, CUDA_LONG N, CastStd cast) { diff --git a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu index ad2a44793fe26..1da308811fa48 100644 --- a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu +++ b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu @@ -104,7 +104,7 @@ struct RoundSat { #endif -#endif +#endif // DISABLE_FLOAT8_TYPES template <> struct RoundStd { @@ -189,7 +189,7 @@ __global__ void QuantizeLinearKernelAxisSat(const InT* input, OutT* output, cons } } -#endif +#endif // DISABLE_FLOAT8_TYPES template Status CudaQuantizeLinearStd(cudaStream_t stream, const InT* input, OutT* output, const InT* scale, const OutT* zero_point, size_t num_of_element) { diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index d587424fe01f8..33f1f59e07f3f 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -118,6 +118,7 @@ static bool IsGPU(IDXCoreAdapter* compute_adapter) { return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS); } +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION static bool IsNPU(IDXCoreAdapter* compute_adapter) { // Only considering hardware adapters if (!IsHardwareAdapter(compute_adapter)) { @@ -125,6 +126,7 @@ static bool IsNPU(IDXCoreAdapter* compute_adapter) { } return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)); } +#endif enum class DeviceType { GPU, NPU, BadDevice }; @@ -134,10 +136,12 @@ static DeviceType FilterAdapterTypeQuery(IDXCoreAdapter* adapter, OrtDmlDeviceFi return DeviceType::GPU; } +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION auto allow_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu; if (IsNPU(adapter) && allow_npus) { return DeviceType::NPU; } +#endif return DeviceType::BadDevice; } @@ -216,6 +220,7 @@ static void SortHeterogenousDXCoreAdapterList( return; } +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION // When considering both GPUs and NPUs sort them by performance preference // of Default (Gpus first), HighPerformance (GPUs first), or LowPower (NPUs first) auto keep_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu; @@ -223,6 +228,7 @@ static void SortHeterogenousDXCoreAdapterList( if (!keep_npus || only_npus) { return; } +#endif struct SortingPolicy { // default is false because GPUs are considered higher priority in @@ -322,23 +328,26 @@ static std::optional ParsePerformancePreference(con static std::optional ParseFilter(const ProviderOptions& provider_options) { static const std::string Filter = "filter"; - static const std::string Any = "any"; static const std::string Gpu = "gpu"; +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION + static const std::string Any = "any"; static const std::string Npu = "npu"; +#endif auto preference_it = provider_options.find(Filter); if (preference_it != provider_options.end()) { - if (preference_it->second == Any) { - return OrtDmlDeviceFilter::Any; - } - if (preference_it->second == Gpu) { return OrtDmlDeviceFilter::Gpu; } +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION + if (preference_it->second == Any) { + return OrtDmlDeviceFilter::Any; + } if (preference_it->second == Npu) { return OrtDmlDeviceFilter::Npu; } +#endif ORT_THROW("Invalid Filter provided for DirectML EP device selection."); } diff --git a/onnxruntime/core/providers/js/operators/concat.cc b/onnxruntime/core/providers/js/operators/concat.cc index 3a6a7e1cafd7a..17c6b0466c3a5 100644 --- a/onnxruntime/core/providers/js/operators/concat.cc +++ b/onnxruntime/core/providers/js/operators/concat.cc @@ -12,7 +12,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 1, 3, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); @@ -22,7 +23,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 4, 10, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); @@ -32,7 +34,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 11, 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); @@ -42,7 +45,8 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), Concat); diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index e9bbfabcf86bd..78563d30b0136 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -123,7 +123,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not) // activation -JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10) JSEP_KERNEL_IMPL(Clip, Clip) ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index e0f060f758b2e..6d8c80bd2aaa1 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -56,8 +56,8 @@ Status BaseOpBuilder::ProcessInput(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } - OnnxInputInfo input_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(input, input_info)); + TensorInfo input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input, input_info)); std::vector unpacked_tensor; if (input_info.is_initializer) { @@ -126,44 +126,38 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, for (size_t output_i = 0; output_i < output_count; ++output_i) { const auto& output_name = outputs[output_i].node_arg.Name(); - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; - bool is_quantized_tensor = outputs[output_i].quant_param.has_value(); - utils::InitializeQuantizeParam(quantize_param, is_quantized_tensor); - - const auto* type_proto = outputs[output_i].node_arg.TypeAsProto(); - Qnn_DataType_t qnn_data_type = QNN_DATATYPE_UNDEFINED; - ORT_RETURN_IF_ERROR(utils::GetQnnDataType(is_quantized_tensor, type_proto, qnn_data_type)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(outputs[output_i].quant_param, - quantize_param.scaleOffsetEncoding.scale, - quantize_param.scaleOffsetEncoding.offset), - "Cannot get quantization parameter"); - std::vector output_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(outputs[output_i].node_arg, output_shape), - "Cannot get shape"); - Qnn_DataType_t supported_qnn_data_type = GetSupportedOutputDataType(output_i, qnn_data_type); + TensorInfo output_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[output_i], output_info)); + + if (output_info.quant_param.encodingDefinition == QNN_DEFINITION_DEFINED) { + ORT_RETURN_IF_ERROR(OverrideOutputQuantParam(qnn_model_wrapper, node_unit, logger, input_names, + output_i, output_info.qnn_data_type, output_info.quant_param)); + } + + Qnn_DataType_t supported_qnn_data_type = GetSupportedOutputDataType(output_i, output_info.qnn_data_type); bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name); - if (supported_qnn_data_type != qnn_data_type && is_graph_output && !do_op_validation) { + if (supported_qnn_data_type != output_info.qnn_data_type && is_graph_output && !do_op_validation) { std::string cast_node_name = output_name + "_ort_qnn_ep_cast"; std::string cast_input_name = output_name + "_ort_qnn_ep_aux"; - std::vector cast_output_shape = output_shape; + std::vector cast_output_shape = output_info.shape; QnnTensorWrapper cast_input_tensorwrapper(cast_input_name, QNN_TENSOR_TYPE_NATIVE, supported_qnn_data_type, - quantize_param, + output_info.quant_param, std::move(cast_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor."); output_names.push_back(cast_input_name); cast_node_info_vec.push_back({cast_node_name, cast_input_name, output_name}); } else { - qnn_data_type = supported_qnn_data_type; + output_info.qnn_data_type = supported_qnn_data_type; output_names.push_back(output_name); } Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, - qnn_data_type, - quantize_param, - std::move(output_shape)); + output_info.qnn_data_type, + output_info.quant_param, + std::move(output_info.shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); } @@ -188,6 +182,46 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +Status BaseOpBuilder::SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t input_index, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const { + const QnnTensorWrapper& input_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[input_index]); + ORT_RETURN_IF_NOT(input_tensor_wrapper.GetTensorDataType() == qnn_data_type, + "Input and output data types do not match"); + Qnn_QuantizeParams_t input_quant_param = GetQnnTensorQParams(input_tensor_wrapper.GetQnnTensor()); + + float scale_diff = 0.0f; + int32_t offset_diff = 0; + ORT_RETURN_IF_ERROR(CompareQnnQuantParams(quant_param, input_quant_param, scale_diff, offset_diff)); + constexpr float NEARLY_EQUAL_THRESHOLD = 1e-9f; + constexpr float WARN_THRESHOLD = 1e-6f; + + if (scale_diff != 0.0f && offset_diff == 0) { + if (scale_diff <= NEARLY_EQUAL_THRESHOLD) { + // Quantization params are nearly equal, so make them equal. This may allow QNN backends to employ certain graph + // optimizations that improve inference latency. + LOGS(logger, WARNING) << "QNN EP will override the output quantization parameters for " << node_unit.OpType() + << " operators to be equal to the input quantization parameters. Operator name: " + << node_unit.Name() << ", input_index: " << input_index << ", output index: " + << output_index << "."; + quant_param = input_quant_param; // Copy input quantization params to the output. + } else if (scale_diff <= WARN_THRESHOLD) { + // Quantization params are just outside of the "nearly equal" threshold, so warn user of potential latency + // degradation. + LOGS(logger, WARNING) << "The quantization parameters for the " << node_unit.OpType() << " operator '" + << node_unit.Name() << "' are not equal, which may result in latency degradation. " + << "input_index: " << input_index << ", output index: " << output_index << "."; + } + } + + return Status::OK(); +} + Status BaseOpBuilder::TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper, const onnx::TensorProto& initializer, const std::vector& perm, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 75a3a6ff2ff46..c979e599f96c4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -38,6 +38,37 @@ class BaseOpBuilder : public IOpBuilder { return qnn_data_type; } + /** + * Allows operator builders that override this function to override output quantization parameters. + * Called by BaseOpBuilder::ProcessOutputs(). + * + * \param qnn_model_wrapper The QNN model that is being built. + * \param node_unit The node unit for which to return output information. + * \param logger The logger. + * \param input_names Names of all inputs consumed by this QNN node. + * \param output_index The index in node_unit.Outputs() of the output for which to return information. + * \param qnn_data_type The output's data type. + * \param quant_param The quantization parameter object that is overridden. + * \return An onnxruntime::Status object indicating failure or success. + */ + virtual Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const ORT_MUST_USE_RESULT { + // Do nothing by default. Op builders like Split implement this function to override output quant params. + ORT_UNUSED_PARAMETER(qnn_model_wrapper); + ORT_UNUSED_PARAMETER(node_unit); + ORT_UNUSED_PARAMETER(logger); + ORT_UNUSED_PARAMETER(input_names); + ORT_UNUSED_PARAMETER(output_index); + ORT_UNUSED_PARAMETER(qnn_data_type); + ORT_UNUSED_PARAMETER(quant_param); + return Status::OK(); + } + virtual Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, @@ -72,6 +103,15 @@ class BaseOpBuilder : public IOpBuilder { return node_name; } + Status SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t input_index, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const ORT_MUST_USE_RESULT; + static const std::string& GetQnnOpType(const std::string& onnx_op_type) { // TODO: Use QNN operator names defined in "QnnOpDef.h" static const std::unordered_map onnx_op_type_to_qnn_op_type = { diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index 8febf09f0e26d..294aa659872c4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -251,7 +251,7 @@ class BatchNormOpBuilder : public BaseOpBuilder { return Status::OK(); } - Status PreprocessMean(const OnnxInputInfo& mean_info, + Status PreprocessMean(const TensorInfo& mean_info, const bool is_npu_backend, const uint8_t* mean_raw_ptr, const size_t mean_raw_ptr_length, @@ -273,7 +273,7 @@ class BatchNormOpBuilder : public BaseOpBuilder { return Status::OK(); } - Status PreprocessStd(const OnnxInputInfo& var_info, + Status PreprocessStd(const TensorInfo& var_info, const bool is_npu_backend, const uint8_t* var_raw_ptr, const size_t var_raw_ptr_length, @@ -297,7 +297,7 @@ class BatchNormOpBuilder : public BaseOpBuilder { return Status::OK(); } - Status PreprocessScale(const OnnxInputInfo& scale_info, + Status PreprocessScale(const TensorInfo& scale_info, const bool is_npu_backend, const uint8_t* scale_raw_ptr, const size_t scale_raw_ptr_length, @@ -325,7 +325,7 @@ class BatchNormOpBuilder : public BaseOpBuilder { return Status::OK(); } - Status PreprocessBias(const OnnxInputInfo& bias_info, + Status PreprocessBias(const TensorInfo& bias_info, const bool is_npu_backend, const uint8_t* bias_raw_ptr, const size_t bias_raw_ptr_length, @@ -354,7 +354,7 @@ class BatchNormOpBuilder : public BaseOpBuilder { return Status::OK(); } - Status Postprocess(const OnnxInputInfo& info, + Status Postprocess(const TensorInfo& info, const bool is_npu_backend, const std::vector& double_tensor, const double rmax, @@ -476,14 +476,14 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, { const std::string& scale_name = inputs[1].node_arg.Name(); const std::string& bias_name = inputs[2].node_arg.Name(); - OnnxInputInfo var_info = {}; - OnnxInputInfo mean_info = {}; - OnnxInputInfo scale_info = {}; - OnnxInputInfo bias_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], scale_info)); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[2], bias_info)); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[3], mean_info)); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[4], var_info)); + TensorInfo var_info = {}; + TensorInfo mean_info = {}; + TensorInfo scale_info = {}; + TensorInfo bias_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], scale_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[2], bias_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[3], mean_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[4], var_info)); // scale, bias, mean, and var must be initializers ORT_RETURN_IF_NOT(scale_info.is_initializer, "scale must be initializers"); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc index df4c718949269..0a9f9889ad2d8 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc @@ -84,8 +84,8 @@ Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector min_val_bytes; if (num_inputs > 1 && !inputs[1].node_arg.Name().empty()) { - OnnxInputInfo min_input_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], min_input_info)); + TensorInfo min_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], min_input_info)); ORT_RETURN_IF_NOT(min_input_info.qnn_data_type == qnn_data_type, "QNN EP: The 'min' input of the Clip operator must be of type float32."); assert(min_input_info.is_initializer); // Checked by ExplicitOpCheck(). @@ -106,8 +106,8 @@ Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector max_val_bytes; if (num_inputs > 2 && !inputs[2].node_arg.Name().empty()) { - OnnxInputInfo max_input_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[2], max_input_info)); + TensorInfo max_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[2], max_input_info)); ORT_RETURN_IF_NOT(max_input_info.qnn_data_type == qnn_data_type, "QNN EP: The 'max' input of the Clip operator must of type float32."); assert(max_input_info.is_initializer); // Checked by ExplicitOpCheck(). diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index e8c5b98129a1e..84b6cad9c41c1 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -175,8 +175,8 @@ Status ConvOpBuilder::ProcessConv2DInputs(QnnModelWrapper& qnn_model_wrapper, // { const std::string& input1_name = inputs[1].node_arg.Name(); - OnnxInputInfo input_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], input_info)); + TensorInfo input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input_info)); std::string actual_name = input_info.is_initializer ? input1_name : input1_name + "_ort_qnn_ep_transpose"; input_names.push_back(actual_name); @@ -267,8 +267,8 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, { const std::string& input0_name = inputs[0].node_arg.Name(); - OnnxInputInfo input0_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[0], input0_info)); + TensorInfo input0_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input0_info)); const std::string conv_input0_name = input0_info.is_initializer ? input0_name : input0_name + "_ort_qnn_ep_reshape"; @@ -318,8 +318,8 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, // { const std::string& input1_name = inputs[1].node_arg.Name(); - OnnxInputInfo input_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], input_info)); + TensorInfo input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input_info)); std::string conv_weight_input_name = input_info.is_initializer ? input1_name : input1_name + "_ort_qnn_ep_transpose"; input_names.push_back(conv_weight_input_name); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc index 7a44060b751cf..90e18e9fd0496 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc @@ -24,6 +24,13 @@ class ExpandOpBuilder : public BaseOpBuilder { const logging::Logger& logger, std::vector& input_names, bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; }; template @@ -131,6 +138,19 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +Status ExpandOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const { + // Force Expand output to use the same quantization parameters as the input if they are nearly equal. + // This enables the HTP backend to employ certain optimizations. + return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, + 0 /*input_index*/, output_index, qnn_data_type, quant_param); +} + void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.AddOpBuilder(op_type, std::make_unique()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc index c441fe331df3a..9f396a27369e7 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc @@ -168,6 +168,13 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w quantize_param.scaleOffsetEncoding.scale, quantize_param.scaleOffsetEncoding.offset), "Cannot get quantization parameter"); + if (is_quantized_tensor) { + // Make sure the output quantization parameters are equal to the input. + ORT_RETURN_IF_ERROR(SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, + 0 /*input_index*/, 0 /*output_index*/, qnn_data_type, + quantize_param)); + } + std::vector target_output_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(gather_output.node_arg, target_output_shape), "Cannot get shape"); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc index 6d39cd8102094..38172caa03768 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc @@ -94,8 +94,8 @@ Status InstanceNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, bool do_op_validation) const { const auto& inputs = node_unit.Inputs(); - OnnxInputInfo input0_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[0], input0_info)); + TensorInfo input0_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input0_info)); // HTP backend can only handle rank 3 inputs if the batch size is 1. If the batch size is not 1, // QNN EP must reshape the input and output to (N, 1, W, C) and process the InstanceNorm as rank 4. @@ -168,8 +168,8 @@ Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_m const auto& outputs = node_unit.Outputs(); - OnnxInputInfo output_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(outputs[0], output_info)); + TensorInfo output_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[0], output_info)); // HTP backend can only handle rank 3 inputs/outputs if the batch size is 1. If the batch size is not 1, // QNN EP must reshape the input and output to (N, 1, W, C) and process the InstanceNorm as rank 4. diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc index 523095fac9aaf..d6752f76ef478 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc @@ -67,8 +67,8 @@ Status ProcessConstantValue(QnnModelWrapper& qnn_model_wrapper, std::vector& param_tensor_names, const NodeUnit& node_unit, const NodeUnitIODef& input) { - OnnxInputInfo input_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(input, input_info)); + TensorInfo input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input, input_info)); std::vector unpacked_tensor; // Already confirmed constant_value input is initializer in ProcessInputs() ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, unpacked_tensor)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index a44640b37ae36..872d9682b8355 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -29,6 +29,13 @@ class PoolOpBuilder : public BaseOpBuilder { std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; private: Status SetCommonPoolParams(const NodeAttrHelper& node_helper, std::vector& filter_size, @@ -237,6 +244,23 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra return Status::OK(); } +Status PoolOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const { + // Force MaxPool outputs to use the same quantization parameters as the input if they are nearly equal. + // This helps the HTP backend employ certain optimizations. + if (node_unit.OpType() == "MaxPool") { + return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, + 0 /*input_index*/, output_index, qnn_data_type, quant_param); + } + + return Status::OK(); +} + void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.AddOpBuilder(op_type, std::make_unique()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc index 73ac81bfc8aef..4b06df6a0e632 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc @@ -23,6 +23,13 @@ class ReshapeOpBuilder : public BaseOpBuilder { const logging::Logger& logger, std::vector& input_names, bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; }; Status ReshapeOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, @@ -44,6 +51,19 @@ Status ReshapeOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +Status ReshapeOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const { + // Force Reshape output to use the same quantization parameters as the input if nearly equal. + // This helps the HTP backend emply certain optimizations. + return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, + 0 /*input_index*/, output_index, qnn_data_type, quant_param); +} + void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.AddOpBuilder(op_type, std::make_unique()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc index 4039c4fbf8d70..cc620b7a86a18 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc @@ -42,6 +42,14 @@ class ResizeOpBuilder : public BaseOpBuilder { const logging::Logger& logger, bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; + private: // Info for each ONNX attribute of interest (attribute name + default value) static const OnnxAttrInfo onnx_mode_attr; @@ -311,6 +319,19 @@ Status ResizeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w logger, do_op_validation, qnn_op_type); } +Status ResizeOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const { + // Force Resize op's output to use the same quantization parameters as the input if nearly equal. + // This helps the HTP backend employ certain optimizations. + return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, + 0 /*input_index*/, output_index, qnn_data_type, quant_param); +} + void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.AddOpBuilder(op_type, std::make_unique()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 4022307b93ff8..fdc5317419c5b 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -22,11 +22,23 @@ class SimpleOpBuilder : public BaseOpBuilder { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SimpleOpBuilder); protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; private: Status ExplicitOpCheck(const NodeUnit& node_unit) const; @@ -41,6 +53,90 @@ class SimpleOpBuilder : public BaseOpBuilder { static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; }; +// Move to qnn_utils if it's re-usable +Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper, + const std::string& convert_input_name, + const std::string& convert_output_name, + Qnn_DataType_t input_qnn_data_type, + Qnn_DataType_t output_qnn_data_type, + int32_t input_offset, + float input_scale, + const std::vector& output_shape, + bool do_op_validation) { + // Assume input is already handled. + float qmin = 0.0f; + float qmax = 255.0f; + ORT_RETURN_IF_ERROR(qnn::utils::GetQminQmax(input_qnn_data_type, qmin, qmax)); + double value_min = qnn::utils::Dequantize(input_offset, input_scale, qmin); + double value_max = qnn::utils::Dequantize(input_offset, input_scale, qmax); + + Qnn_QuantizeParams_t convert_output_quant_param = QNN_QUANTIZE_PARAMS_INIT; + convert_output_quant_param.encodingDefinition = QNN_DEFINITION_DEFINED; + convert_output_quant_param.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + ORT_RETURN_IF_ERROR(qnn::utils::GetQuantParams(static_cast(value_min), + static_cast(value_max), + output_qnn_data_type, + convert_output_quant_param.scaleOffsetEncoding.scale, + convert_output_quant_param.scaleOffsetEncoding.offset)); + + std::vector output_shape_copy = output_shape; + QnnTensorWrapper convert_output_tensorwrapper(convert_output_name, + QNN_TENSOR_TYPE_NATIVE, + output_qnn_data_type, + convert_output_quant_param, + std::move(output_shape_copy)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(convert_output_tensorwrapper)), "Failed to add tensor."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(convert_output_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + "Convert", + {convert_input_name}, + {convert_output_name}, + {}, + do_op_validation), + "Failed to add node."); + return Status::OK(); +} + +Status SimpleOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + const std::string& op_type = node_unit.OpType(); + ORT_RETURN_IF_ERROR(BaseOpBuilder::ProcessInputs(qnn_model_wrapper, node_unit, logger, input_names, do_op_validation)); + + if (op_type == "MatMul") { + const auto& inputs = node_unit.Inputs(); + TensorInfo input0_info = {}; + TensorInfo input1_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input0_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input1_info)); + // Need to insert Convert op if both inputs are dynamic inputs and are ufixed_16 + if (!input0_info.is_initializer && !input1_info.is_initializer && + input0_info.qnn_data_type == input1_info.qnn_data_type && + input0_info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + // insert Convert op after input1 + std::string convert_input_name = input_names.back(); + input_names.pop_back(); + const std::string& matmul_output_name = node_unit.Outputs()[0].node_arg.Name(); + std::string convert_output_name = convert_input_name + "_convert_" + matmul_output_name; + ORT_RETURN_IF_ERROR(InsertConvertOp(qnn_model_wrapper, + convert_input_name, + convert_output_name, + input1_info.qnn_data_type, + QNN_DATATYPE_UFIXED_POINT_8, + input1_info.quant_param.scaleOffsetEncoding.offset, + input1_info.quant_param.scaleOffsetEncoding.scale, + input1_info.shape, + do_op_validation)); + input_names.push_back(convert_output_name); + } + } + + return Status::OK(); +} + Status SimpleOpBuilder::ExplicitOpCheck(const NodeUnit& node_unit) const { const std::string& op_type = node_unit.OpType(); @@ -275,16 +371,6 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w ORT_RETURN_IF_ERROR(ProcessGridSampleAttributes(qnn_model_wrapper, node_unit, param_tensor_names)); } - if (op_type == "Sigmoid" || op_type == "Tanh") { - // QNN requires 16-bit QDQ Sigmoid and Tanh to use specific output scale and zero-point values - // regardless of floating-point range. - return ProcessSigmoidOrTanhOutput(qnn_model_wrapper, - node_unit, - std::move(input_names), - std::move(param_tensor_names), - logger, do_op_validation); - } - return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), @@ -338,61 +424,43 @@ static bool OverrideQuantParams(const std::string& op_type, Qnn_DataType_t qnn_d return quant_params.offset != orig_offset || quant_params.scale != orig_scale; } -/** - * Processes the output for Sigmoid or Tanh operators and creates the corresponding QNN operator. - * These operator types are handled separately because QNN requires 16-bit QDQ Sigmoid and Tanh operators to use - * specific scale and zero-point values regardless of floating-point range. - * - * \param qnn_model_wrapper The QNN model wrapper object. - * \param node_unit The QDQ node unit for the Sigmoid or Tanh node. - * \param input_names List of input names. - * \param param_tensor_names List of param tensor names. - * \param logger Logger used to report information. - * \param do_op_validation True if the new QNN node should be validated. - */ -Status SimpleOpBuilder::ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - std::vector&& param_tensor_names, - const logging::Logger& logger, - bool do_op_validation) const { +Status SimpleOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const { + ORT_UNUSED_PARAMETER(input_names); const std::string& op_type = node_unit.OpType(); - const auto& output = node_unit.Outputs()[0]; - const std::string& output_name = output.node_arg.Name(); - - OnnxInputInfo output_info = {}; - - // TODO(adrianlizarraga): Rename GetOnnxInputInfo() since it can be used for outputs as well. - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(output, output_info)); - if (output_info.quant_param.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { - if (OverrideQuantParams(op_type, output_info.qnn_data_type, output_info.quant_param.scaleOffsetEncoding)) { - const int32_t offset = output_info.quant_param.scaleOffsetEncoding.offset; - const float scale = output_info.quant_param.scaleOffsetEncoding.scale; - - LOGS(logger, VERBOSE) << "QNN requires that 16-bit quantized " << op_type << " operators use offset/scale values " - << "of <" << offset << ", " << scale << ">. QNN EP will override the original values for output " - << output_name; + // Override output quantization parameters for uint16 QDQ Sigmoid or Tanh. + // QNN requires 16-bit QDQ Sigmoid and Tanh to use specific output scale and zero-point values + // regardless of floating-point range. + if (op_type == "Sigmoid" || op_type == "Tanh") { + const auto& outputs = node_unit.Outputs(); + ORT_RETURN_IF_NOT(output_index < outputs.size(), + "Invalid output index in OverrideOutputQuantParam for op ", op_type.c_str()); + + const auto& output = node_unit.Outputs()[0]; + const std::string& output_name = output.node_arg.Name(); + + if (quant_param.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + if (OverrideQuantParams(op_type, qnn_data_type, quant_param.scaleOffsetEncoding)) { + const int32_t offset = quant_param.scaleOffsetEncoding.offset; + const float scale = quant_param.scaleOffsetEncoding.scale; + + LOGS(logger, VERBOSE) << "QNN requires that 16-bit quantized " << op_type + << " operators use offset/scale values " + << "of <" << offset << ", " << scale + << ">. QNN EP will override the original values for output " << output_name; + ORT_RETURN_IF(qnn_model_wrapper.IsQnnTensorWrapperExist(output_name), + "QNN EP is unable to override output quantization parameters for ", op_type.c_str(), + " operator. Node name: ", node_unit.Name().c_str(), ", output name: ", output_name.c_str()); + } } } - ORT_RETURN_IF(qnn_model_wrapper.IsQnnTensorWrapperExist(output_name), - "QNN EP is unable to override output quantization parameters for ", op_type.c_str(), - " operator. Node name: ", node_unit.Name().c_str(), ", output name: ", output_name.c_str()); - Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ - : QNN_TENSOR_TYPE_NATIVE; - QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type, output_info.quant_param, - std::move(output_info.shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit), - QNN_OP_PACKAGE_NAME_QTI_AISW, - GetQnnOpType(op_type), - std::move(input_names), - {output_name}, - std::move(param_tensor_names), - do_op_validation), - "Failed to add node."); - return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc index 49d85d76e25a8..9059f7459200a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc @@ -97,8 +97,8 @@ Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis)); - OnnxInputInfo input_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[0], input_info)); + TensorInfo input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info)); const size_t input_rank = input_info.shape.size(); // If the axis attribute refers to the last dimension, then process the input as normal. @@ -161,8 +161,8 @@ Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_ Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis)); - OnnxInputInfo output_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(outputs[0], output_info)); + TensorInfo output_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[0], output_info)); const size_t output_rank = output_info.shape.size(); const bool axis_is_last_dim = static_cast(axis) == output_rank - 1; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc index 6812c223f7c90..f4b0d1ff59175 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc @@ -30,6 +30,14 @@ class SplitOpBuilder : public BaseOpBuilder { std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; }; Status SplitOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, @@ -121,6 +129,23 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr return Status::OK(); } +Status SplitOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const { + // Force Split outputs to use the same quantization parameters as the input if nearly equal. + // This helps the HTP backend employ certain optimizations. + // + // The quantization tool assigns equal qparams to the input and outputs. + // However, Sigmoid/Tanh may override their output qparams, + // which requires us to explicitly handle this in case a Split is consumer of a Sigmoid/Tanh node. + return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, + 0 /*input_index*/, output_index, qnn_data_type, quant_param); +} + void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.AddOpBuilder(op_type, std::make_unique()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc index bf194a3c71337..721db9dd2670e 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/tile_op_builder.cc @@ -30,6 +30,14 @@ class TileOpBuilder : public BaseOpBuilder { std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const override ORT_MUST_USE_RESULT; }; Status TileOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, @@ -86,6 +94,19 @@ Status TileOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra return Status::OK(); } +Status TileOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + const std::vector& input_names, + size_t output_index, + Qnn_DataType_t qnn_data_type, + Qnn_QuantizeParams_t& quant_param) const { + // Force the Tile operator output to use the same quantization parameters as the input if nearly equal. + // This helps the HTP backend employ certain optimizations. + return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, + 0 /*input_index*/, output_index, qnn_data_type, quant_param); +} + void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.AddOpBuilder(op_type, std::make_unique()); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index f4eb7b2a2b158..a77ac16cf624b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -201,6 +201,30 @@ const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor) } } +Status CompareQnnQuantParams(const Qnn_QuantizeParams_t& qparam0, const Qnn_QuantizeParams_t& qparam1, + float& scale_diff, int32_t& offset_diff) { + scale_diff = 0.0f; + offset_diff = 0; + + ORT_RETURN_IF_NOT((qparam0.encodingDefinition == qparam1.encodingDefinition && + qparam0.quantizationEncoding == qparam1.quantizationEncoding), + "Expected quantization parameters to be the same type."); + + if (qparam0.encodingDefinition == QNN_DEFINITION_DEFINED) { + switch (qparam0.quantizationEncoding) { + case QNN_QUANTIZATION_ENCODING_SCALE_OFFSET: { + scale_diff = std::abs(qparam0.scaleOffsetEncoding.scale - qparam1.scaleOffsetEncoding.scale); + offset_diff = std::abs(qparam0.scaleOffsetEncoding.offset - qparam1.scaleOffsetEncoding.offset); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported quantization encoding: ", qparam0.quantizationEncoding); + } + } + + return Status::OK(); +} + bool CreateTensorInQnnGraph(const QNN_INTERFACE_VER_TYPE& qnn_interface, const Qnn_GraphHandle_t& graph, const std::string& node_name, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 66154fcf346ee..f6a3b1bd360ec 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -122,6 +122,20 @@ uint32_t* GetQnnTensorDims(const Qnn_Tensor_t& qnn_tensor); const Qnn_ClientBuffer_t& GetQnnTensorClientBuf(const Qnn_Tensor_t& qnn_tensor); const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor); +/** + * Compares two sets of quantization parameters. Sets the parameters `scale_diff` and `offset_diff` + * to the absolute differences. Returns an error status if the quantization parameters are not + * of the same type, or if the type is not supported. + * + * \param qparam0 The first set of quantization parameters. + * \param qparam1 The second set of quantization parameters. + * \param scale_diff Set to the absolute value of the difference in scale value. + * \param offset_diff Set to the absolute value of the difference in offset value. + * \return Status indicating success. + */ +Status CompareQnnQuantParams(const Qnn_QuantizeParams_t& qparam0, const Qnn_QuantizeParams_t& qparam1, + float& max_scale_diff, int32_t& max_offset_diff); + // TODO: split out separate files for Wrappers class QnnTensorWrapper { public: diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 9d339387b0a43..a422434205c68 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -365,33 +365,33 @@ bool QnnModelWrapper::ProcessQuantizationParameter(const std::optional shape; Qnn_DataType_t qnn_data_type; Qnn_QuantizeParams_t quant_param; @@ -117,8 +117,7 @@ class QnnModelWrapper { return input_index_map_.find(tensor_name) != input_index_map_.end(); } - // TODO(hecli) rename to GetTensorInfo - Status GetOnnxInputInfo(const NodeUnitIODef& input, OnnxInputInfo& input_info) const; + Status GetTensorInfo(const NodeUnitIODef& input, TensorInfo& input_info) const; Status AddReshapeNode(const std::string& input_name, const std::string& output_name, diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 7f5ab3a772305..79f84864a5788 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1194,6 +1194,11 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { } } + if (external_stream_) { + ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_))); + ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_))); + } + if (!external_stream_ && stream_) { ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_))); } @@ -1272,6 +1277,20 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { return Status::OK(); } +// Get the pointer to the IBuilder instance. +// Note: This function is not thread safe. Calls to this function from different threads must be serialized +// even though it doesn't make sense to have multiple threads initializing the same inference session. +nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { + if (!builder_) { + TensorrtLogger& trt_logger = GetTensorrtLogger(); + { + auto lock = GetApiLock(); + builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + } + } + return builder_.get(); +} + void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { if (info_.custom_op_domain_list.empty()) { common::Status status = CreateTensorRTCustomOpDomainList(info_); @@ -1562,7 +1581,9 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } } - if (has_control_flow_op) { + // Only if the newly built graph has control flow op as well as it has parent node, + // it needs to handle outer scope values before calling graph.Resolve(). + if (has_control_flow_op && graph.ParentNode()) { LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Handle outer scope values for the subgraph " << graph_build.Name(); BuildSubGraphContext(graph_build); SetGraphOuterScopeValuesAndInputs(graph_build, graph.GetGraph()); @@ -1631,7 +1652,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect // Get supported node list recursively SubGraphCollection_t parser_nodes_list; TensorrtLogger& trt_logger = GetTensorrtLogger(); - auto trt_builder = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + auto trt_builder = GetBuilder(); const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(explicitBatch)); @@ -1808,6 +1829,10 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, if (sub_graphs.size() != 0) { bool all_subgraphs_are_supported = true; for (auto sub_graph : sub_graphs) { + // TRT EP should consider the empty subgraph is fully supported by TRT. + if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { + continue; + } if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { all_subgraphs_are_supported = false; break; @@ -1875,27 +1900,33 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, auto sub_graphs = graph.ParentNode()->GetSubgraphs(); for (auto sub_graph : sub_graphs) { if (sub_graph.get() != &graph.GetGraph()) { - auto sub_graph_veiwer = sub_graph->CreateGraphViewer(); - const int number_of_ort_subgraph_nodes = sub_graph_veiwer->NumberOfNodes(); + auto sub_graph_viewer = sub_graph->CreateGraphViewer(); + const int number_of_ort_subgraph_nodes = sub_graph_viewer->NumberOfNodes(); std::vector subgraph_nodes_vector(number_of_ort_subgraph_nodes); std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; bool subgraph_early_termination = false; - // Another subgraph of "If" control flow has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. - if (AllNodesAssignedToSpecificEP(*sub_graph_veiwer, kTensorrtExecutionProvider)) { + // Another subgraph of "If" control flow op has no nodes. + // In this case, TRT EP should consider this empty subgraph is fully supported by TRT. + if (sub_graph_viewer->NumberOfNodes() == 0) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. + else if (AllNodesAssignedToSpecificEP(*sub_graph_viewer, kTensorrtExecutionProvider)) { all_subgraphs_are_supported = true; break; } // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP. // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs) - else if (!AllNodesAssignedToSpecificEP(*sub_graph_veiwer, "")) { + else if (!AllNodesAssignedToSpecificEP(*sub_graph_viewer, "")) { all_subgraphs_are_supported = false; break; } // Another subgraph of "If" control flow has not yet been parsed by GetCapability. - subgraph_supported_nodes_vector = GetSupportedList(parser_subgraph_nodes_vector, 0, max_partition_iterations_, *sub_graph_veiwer, &subgraph_early_termination); + subgraph_supported_nodes_vector = GetSupportedList(parser_subgraph_nodes_vector, 0, max_partition_iterations_, *sub_graph_viewer, &subgraph_early_termination); all_subgraphs_are_supported = IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes); break; } @@ -1983,7 +2014,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(nvinfer1::createInferBuilder(trt_logger)); + auto trt_builder = GetBuilder(); const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(explicitBatch)); auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); @@ -2436,7 +2467,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, - &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], + *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), + &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, @@ -2488,7 +2518,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsync_stream_after_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; - auto trt_builder = trt_state->builder->get(); + auto trt_builder = trt_state->builder; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index cda08715ea009..a945d219088aa 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -105,10 +105,10 @@ struct TensorrtFuncState { DestroyFunc test_release_func = nullptr; AllocatorHandle allocator = nullptr; std::string fused_node_name; + nvinfer1::IBuilder* builder; tensorrt_ptr::unique_pointer* parser = nullptr; std::unique_ptr* engine = nullptr; std::unique_ptr* context = nullptr; - std::unique_ptr* builder = nullptr; std::unique_ptr* network = nullptr; std::vector> input_info; std::vector> output_info; @@ -245,6 +245,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; mutable std::unordered_map> subgraph_context_map_; + mutable std::unique_ptr builder_; + // Following maps that hold TRT objects will be accessible by different threads if ORT is using multithreading. // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client. // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading @@ -456,5 +458,11 @@ class TensorrtExecutionProvider : public IExecutionProvider { void CaptureBegin(); void CaptureEnd(); void IncrementRegularRunCountBeforeGraphCapture(); + + /** + * Get the pointer to the IBuilder instance. + * This function only creates the instance at the first time it's being called." + */ + nvinfer1::IBuilder* GetBuilder() const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index ecc72b1c65476..92fa101118506 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -3,10 +3,36 @@ #include "core/providers/shared_library/provider_api.h" #include "tensorrt_execution_provider.h" +#include "core/framework/murmurhash3.h" #include namespace onnxruntime { +namespace { +// Get unique graph name based on graph's name and all nodes' name +std::string GetUniqueGraphName(const Graph& graph) { + HashValue model_hash = 0; + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + // Hash all nodes' name + for (int i = 0; i < graph.MaxNodeIndex(); ++i) { + auto node = graph.GetNode(i); + if (node == nullptr) { + continue; + } + hash_str(node->Name()); + } + + model_hash = hash[0] | (uint64_t(hash[1]) << 32); + + return graph.Name() + "_" + std::to_string(model_hash); +} +} // namespace + // The newly-built graph has not yet being resolved by Graph::Resolve(), so we can't leverage // Graph::ResolveContext::IsInputInitializerOrOutput(). We have to implement this fuction again. bool TensorrtExecutionProvider::IsInputInitializerOrOutput(const Graph& graph, @@ -31,10 +57,11 @@ bool TensorrtExecutionProvider::IsOuterScopeValue(const Graph& graph, // Graph::ResolveContext::IsLocalValue(). We have to implement this function again. bool TensorrtExecutionProvider::IsLocalValue(const Graph& graph, const std::string& name) const { - if (subgraph_context_map_.find(graph.Name()) == subgraph_context_map_.end()) { + std::string unique_graph_name = GetUniqueGraphName(graph); + if (subgraph_context_map_.find(unique_graph_name) == subgraph_context_map_.end()) { return false; } - SubGraphContext* context = subgraph_context_map_.at(graph.Name()).get(); + SubGraphContext* context = subgraph_context_map_.at(unique_graph_name).get(); return context->output_args.find(name) != context->output_args.cend() || context->inputs_and_initializers.find(name) != context->inputs_and_initializers.cend(); } @@ -59,13 +86,15 @@ void TensorrtExecutionProvider::BuildSubGraphContext(const Graph& graph) const { } } + std::string unique_graph_name = GetUniqueGraphName(graph); + // Subgraph context has been built before, no need to do it again - if (subgraph_context_map_.find(graph.Name()) != subgraph_context_map_.end()) { + if (subgraph_context_map_.find(unique_graph_name) != subgraph_context_map_.end()) { return; } - subgraph_context_map_.emplace(graph.Name(), std::make_unique()); - SubGraphContext* context = subgraph_context_map_.at(graph.Name()).get(); + subgraph_context_map_.emplace(unique_graph_name, std::make_unique()); + SubGraphContext* context = subgraph_context_map_.at(unique_graph_name).get(); // Collect all nodes' outputs and nodes' name for (int i = 0; i < graph.MaxNodeIndex(); ++i) { @@ -138,13 +167,14 @@ void TensorrtExecutionProvider::SetGraphOuterScopeValuesAndInputs(Graph& graph_b while (top_level_graph->MutableParentGraph()) { top_level_graph = top_level_graph->MutableParentGraph(); } - if (subgraph_context_map_.find(top_level_graph->Name()) == subgraph_context_map_.end()) { + std::string unique_graph_name = GetUniqueGraphName(*top_level_graph); + if (subgraph_context_map_.find(unique_graph_name) == subgraph_context_map_.end()) { LOGS_DEFAULT(ERROR) << "[TensorRT EP] Can't find top-level graph context. \ Please check BuildSubGraphContext() has built the graph context correctly."; return; } - SubGraphContext* context = subgraph_context_map_.at(top_level_graph->Name()).get(); + SubGraphContext* context = subgraph_context_map_.at(unique_graph_name).get(); LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Subgraph name is " << graph_build.Name(); LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Its parent node is " << graph.ParentNode()->Name(); @@ -197,12 +227,13 @@ void TensorrtExecutionProvider::SetGraphOuterScopeValuesAndInputs(Graph& graph_b void TensorrtExecutionProvider::SetAllGraphInputs(Graph& graph) const { // If ORT TRT doesn't manully set graph input in TensorrtExecutionProvider::SetGraphOuterScopeValuesAndInputs(), // Graph::Resolve() will help set graph inputs in Graph::SetGraphInputsOutputs(), so no need to set graph inputs here. - if (subgraph_context_map_.find(graph.Name()) == subgraph_context_map_.end() || - subgraph_context_map_[graph.Name()].get()->manually_added_graph_inputs.size() == 0) { + std::string unique_graph_name = GetUniqueGraphName(graph); + if (subgraph_context_map_.find(unique_graph_name) == subgraph_context_map_.end() || + subgraph_context_map_[unique_graph_name].get()->manually_added_graph_inputs.size() == 0) { return; } - SubGraphContext* context = subgraph_context_map_[graph.Name()].get(); + SubGraphContext* context = subgraph_context_map_[unique_graph_name].get(); std::vector graph_inputs_including_initializers; std::unordered_set graph_inputs_including_initializers_set; diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 38266f566e6e1..d34cb7e362446 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -85,7 +85,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const auto* node(graph_viewer.GetNode(node_idx)); bool supported = false; // Firstly check if platform supports the WebNN op. - if (CheckSingleOp(node->OpType(), wnn_builder_)) { + if (CheckSingleOp(node->OpType(), wnn_builder_, device_type)) { LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; supported = IsNodeSupported(*node, graph_viewer, device_type, logger); } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 46c456556e016..28b54b9c9cf8d 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -30,6 +30,11 @@ enum class WebnnDeviceType { GPU, }; +typedef struct { + std::string opName; + bool isCpuSupported; // The WebNN CPU backend XNNPack supports it (not about the CPU EP). +} WebnnOpInfo; + bool GetShape(const NodeArg& node_arg, std::vector& shape, const logging::Logger& logger); template @@ -128,88 +133,107 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v const emscripten::val& wnn_builder_, const WebnnDeviceType device_type, const logging::Logger& logger); -static const InlinedHashMap op_map = { - {"Abs", "abs"}, - {"Add", "add"}, - {"ArgMax", "argMax"}, - {"ArgMin", "argMin"}, - {"AveragePool", "averagePool2d"}, - {"BatchNormalization", "meanVarianceNormalization"}, - {"Cast", "cast"}, - {"Ceil", "ceil"}, - {"Clip", "clamp"}, - {"Concat", "concat"}, - {"Conv", "conv2d"}, - {"ConvTranspose", "convTranspose2d"}, - {"Cos", "cos"}, - {"Div", "div"}, - {"Elu", "elu"}, - {"Equal", "equal"}, - {"Erf", "erf"}, - {"Exp", "exp"}, - {"Expand", "expand"}, - {"Flatten", "flattenTo2d"}, - {"Floor", "floor"}, - {"Gather", "gather"}, - {"Gemm", "gemm"}, - {"GlobalAveragePool", "averagePool2d"}, - {"GlobalMaxPool", "maxPool2d"}, - {"GlobalLpPool", "l2Pool2d"}, - {"Greater", "greater"}, - {"GroupNormalization", "meanVarianceNormalization"}, - {"HardSigmoid", "hardSigmoid"}, - {"HardSwish", "hardSwish"}, - {"Identity", "identity"}, - {"InstanceNormalization", "meanVarianceNormalization"}, - {"LayerNormalization", "meanVarianceNormalization"}, - {"LeakyRelu", "leakyRelu"}, - {"Less", "lesser"}, - {"Log", "log"}, - {"LpPool", "l2Pool2d"}, - {"MatMul", "matmul"}, - {"Max", "max"}, - {"MaxPool", "maxPool2d"}, - {"Min", "min"}, - {"Mul", "mul"}, - {"Neg", "neg"}, - {"Not", "logicalNot"}, - {"Pad", "pad"}, - {"Pow", "pow"}, - {"PRelu", "prelu"}, - {"Reciprocal", "reciprocal"}, - {"ReduceL1", "reduceL1"}, - {"ReduceL2", "reduceL2"}, - {"ReduceLogSum", "reduceLogSum"}, - {"ReduceLogSumExp", "reduceLogSumExp"}, - {"ReduceMax", "reduceMax"}, - {"ReduceMean", "reduceMean"}, - {"ReduceMin", "reduceMin"}, - {"ReduceProd", "reduceProduct"}, - {"ReduceSum", "reduceSum"}, - {"ReduceSumSquare", "reduceSumSquare"}, - {"Relu", "relu"}, - {"Reshape", "reshape"}, - {"Resize", "resample2d"}, - {"Shape", "slice"}, - {"Sigmoid", "sigmoid"}, - {"Softplus", "softplus"}, - {"Softsign", "softsign"}, - {"Sin", "sin"}, - {"Slice", "slice"}, - {"Softmax", "softmax"}, - {"Split", "split"}, - {"Sqrt", "sqrt"}, - {"Squeeze", "squeeze"}, - {"Sub", "sub"}, - {"Tan", "tan"}, - {"Tanh", "tanh"}, - {"Transpose", "transpose"}, - {"Unsqueeze", "unsqueeze"}, - {"Where", "elementwiseIf"}, +static const InlinedHashMap op_map = { + {"Abs", {"abs", true}}, + {"Add", {"add", true}}, + {"ArgMax", {"argMax", false}}, + {"ArgMin", {"argMin", false}}, + {"AveragePool", {"averagePool2d", true}}, + {"BatchNormalization", {"meanVarianceNormalization", false}}, + {"Cast", {"cast", false}}, + {"Ceil", {"ceil", true}}, + {"Clip", {"clamp", true}}, + {"Concat", {"concat", true}}, + {"Conv", {"conv2d", true}}, + {"ConvTranspose", {"convTranspose2d", true}}, + {"Cos", {"cos", false}}, + {"Div", {"div", true}}, + {"Elu", {"elu", true}}, + {"Equal", {"equal", false}}, + {"Erf", {"erf", false}}, + {"Exp", {"exp", false}}, + {"Expand", {"expand", false}}, + {"Flatten", {"flattenTo2d", false}}, + {"Floor", {"floor", true}}, + {"Gather", {"gather", false}}, + {"Gemm", {"gemm", true}}, + {"GlobalAveragePool", {"averagePool2d", true}}, + {"GlobalMaxPool", {"maxPool2d", true}}, + {"GlobalLpPool", {"l2Pool2d", false}}, + {"Greater", {"greater", false}}, + {"GreaterOrEqual", {"greaterOrEqual", false}}, + {"GroupNormalization", {"meanVarianceNormalization", false}}, + {"HardSigmoid", {"hardSigmoid", false}}, + {"HardSwish", {"hardSwish", true}}, + {"Identity", {"identity", false}}, + {"InstanceNormalization", {"meanVarianceNormalization", false}}, + {"LayerNormalization", {"meanVarianceNormalization", false}}, + {"LeakyRelu", {"leakyRelu", true}}, + {"Less", {"lesser", false}}, + {"LessOrEqual", {"lesserOrEqual", false}}, + {"Log", {"log", false}}, + {"LpPool", {"l2Pool2d", false}}, + {"MatMul", {"matmul", false}}, + {"Max", {"max", true}}, + {"MaxPool", {"maxPool2d", true}}, + {"Min", {"min", true}}, + {"Mul", {"mul", true}}, + {"Neg", {"neg", true}}, + {"Not", {"logicalNot", false}}, + {"Pad", {"pad", true}}, + {"Pow", {"pow", true}}, + {"PRelu", {"prelu", true}}, + {"Reciprocal", {"reciprocal", false}}, + {"ReduceL1", {"reduceL1", false}}, + {"ReduceL2", {"reduceL2", false}}, + {"ReduceLogSum", {"reduceLogSum", false}}, + {"ReduceLogSumExp", {"reduceLogSumExp", false}}, + {"ReduceMax", {"reduceMax", false}}, + {"ReduceMean", {"reduceMean", true}}, + {"ReduceMin", {"reduceMin", false}}, + {"ReduceProd", {"reduceProduct", false}}, + {"ReduceSum", {"reduceSum", true}}, + {"ReduceSumSquare", {"reduceSumSquare", false}}, + {"Relu", {"relu", true}}, + {"Reshape", {"reshape", true}}, + {"Resize", {"resample2d", true}}, + {"Shape", {"slice", true}}, + {"Sigmoid", {"sigmoid", true}}, + {"Softplus", {"softplus", false}}, + {"Softsign", {"softsign", false}}, + {"Sin", {"sin", false}}, + {"Slice", {"slice", true}}, + {"Softmax", {"softmax", true}}, + {"Split", {"split", true}}, + {"Sqrt", {"sqrt", false}}, + {"Squeeze", {"squeeze", false}}, + {"Sub", {"sub", true}}, + {"Tan", {"tan", false}}, + {"Tanh", {"tanh", true}}, + {"Transpose", {"transpose", true}}, + {"Unsqueeze", {"unsqueeze", false}}, + {"Where", {"elementwiseIf", false}}, }; -inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_) { - return op_map.find(op_type) != op_map.end() && wnn_builder_[op_map.find(op_type)->second].as(); +inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder_, + const WebnnDeviceType device_type) { + // Returns false if the op_type is not listed in the op_map. + if (op_map.find(op_type) == op_map.end()) { + return false; + } + // Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser. + if (!wnn_builder_[op_map.find(op_type)->second.opName].as()) { + return false; + } + // The current WebNN CPU (XNNPack) backend supports a limited op list, and we'd rather + // fall back early to the ORT CPU EP rather than fail in the WebNN "cpu" deviceType. + // This is a workaround because the op may be included in MLGraphBuilder for DirectML + // backend but without XNNPack implementation in Chromium. + if (!op_map.find(op_type)->second.isCpuSupported) { + return false; + } + + return true; } constexpr std::array supported_cpu_data_types = { diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 4cb49d8f8cd3a..c8f58fa98635f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -35,8 +35,12 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons output = model_builder.GetBuilder().call("equal", input0, input1); } else if (op_type == "Greater") { output = model_builder.GetBuilder().call("greater", input0, input1); + } else if (op_type == "GreaterOrEqual") { + output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1); } else if (op_type == "Less") { output = model_builder.GetBuilder().call("lesser", input0, input1); + } else if (op_type == "LessOrEqual") { + output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -54,7 +58,9 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& { "Equal", "Greater", + "GreaterOrEqual", "Less", + "LessOrEqual", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 8778bb2414108..e48cf35012652 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -114,6 +114,22 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, if (!GetShape(*input_defs[0], input_shape, logger)) { return false; } + + if (input_defs.size() < 3) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data, starts, ends) but got " + << input_defs.size(); + return false; + } + + // Inputs: starts, ends, axes, and steps must be constant initializers if present. + for (size_t i = 1; i < input_defs.size(); i++) { + if (!Contains(initializers, input_defs[i]->Name())) { + LOGS(logger, VERBOSE) << "Input [" << input_defs[i]->Name() << "] of " << op_type + << " [" << name << "] must be known as initializer"; + return false; + } + } + if (input_defs.size() == 5) { // Check steps. const auto& steps_tensor = *initializers.at(input_defs[4]->Name()); std::vector unpacked_tensor; @@ -140,18 +156,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } - if (input_defs.size() < 3) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data starts and ends) but got " - << input_defs.size(); - return false; - } - - const auto& starts_name = input_defs[1]->Name(); - const auto& ends_name = input_defs[2]->Name(); - if (!Contains(initializers, starts_name) || !Contains(initializers, ends_name)) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] need starts and ends as initializer."; - return false; - } return true; } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index d26a1ca5ce327..b6631263dfb93 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -64,12 +64,15 @@ void ModelBuilder::PreprocessActivations() { const auto& op_type(node->OpType()); if (op_type == "Clip") { - float minValue, maxValue; - GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_); - emscripten::val options = emscripten::val::object(); - options.set("minValue", minValue); - options.set("maxValue", maxValue); - activation_nodes_.emplace(node->Index(), wnn_builder_.call("clamp", options)); + // Temporarily disable clamp fusion for WebNN GPU as which is not supported yet. + if (wnn_device_type_ == WebnnDeviceType::CPU) { + float minValue, maxValue; + GetClipMinMax(GetInitializerTensors(), *node, minValue, maxValue, logger_); + emscripten::val options = emscripten::val::object(); + options.set("minValue", minValue); + options.set("maxValue", maxValue); + activation_nodes_.emplace(node->Index(), wnn_builder_.call("clamp", options)); + } } else if (op_type == "Elu") { NodeAttrHelper helper(*node); emscripten::val options = emscripten::val::object(); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 65dc8ddbeaf90..463317a4dafda 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -99,7 +99,9 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { { // Logical CreateLogicalOpBuilder("Equal", op_registrations); CreateLogicalOpBuilder("Greater", op_registrations); + CreateLogicalOpBuilder("GreaterOrEqual", op_registrations); CreateLogicalOpBuilder("Less", op_registrations); + CreateLogicalOpBuilder("LessOrEqual", op_registrations); } { // Max/Min diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 02a3d16b5b64f..4da54aaad3a33 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -17,8 +17,8 @@ namespace onnxruntime { -WebNNExecutionProvider::WebNNExecutionProvider( - const std::string& webnn_device_flags, const std::string& webnn_power_flags) +WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags, + const std::string& webnn_threads_number, const std::string& webnn_power_flags) : IExecutionProvider{onnxruntime::kWebNNExecutionProvider, true} { // Create WebNN context and graph builder. const emscripten::val ml = emscripten::val::global("navigator")["ml"]; @@ -31,6 +31,10 @@ WebNNExecutionProvider::WebNNExecutionProvider( if (webnn_device_flags.compare("cpu") == 0) { preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; + // Set "numThreads" if it's not default 0. + if (webnn_threads_number.compare("0") != 0) { + context_options.set("numThreads", stoi(webnn_threads_number)); + } } else { preferred_layout_ = DataLayout::NCHW; wnn_device_type_ = webnn::WebnnDeviceType::GPU; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index f8d9a1c33f6c8..13a475327dc0c 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -18,7 +18,8 @@ class Model; class WebNNExecutionProvider : public IExecutionProvider { public: - WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_power_flags); + WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_threads_number, + const std::string& webnn_power_flags); virtual ~WebNNExecutionProvider(); std::vector> diff --git a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc index 4d6b04c8e76d8..11acec8b1f354 100644 --- a/onnxruntime/core/providers/webnn/webnn_provider_factory.cc +++ b/onnxruntime/core/providers/webnn/webnn_provider_factory.cc @@ -10,23 +10,26 @@ using namespace onnxruntime; namespace onnxruntime { struct WebNNProviderFactory : IExecutionProviderFactory { - WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_power_flags) - : webnn_device_flags_(webnn_device_flags), webnn_power_flags_(webnn_power_flags) {} + WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_threads_number, + const std::string& webnn_power_flags) + : webnn_device_flags_(webnn_device_flags), webnn_threads_number_(webnn_threads_number), webnn_power_flags_(webnn_power_flags) {} ~WebNNProviderFactory() override {} std::unique_ptr CreateProvider() override; std::string webnn_device_flags_; + std::string webnn_threads_number_; std::string webnn_power_flags_; }; std::unique_ptr WebNNProviderFactory::CreateProvider() { - return std::make_unique(webnn_device_flags_, webnn_power_flags_); + return std::make_unique(webnn_device_flags_, webnn_threads_number_, webnn_power_flags_); } std::shared_ptr WebNNProviderFactoryCreator::Create( const ProviderOptions& provider_options) { return std::make_shared(provider_options.at("deviceType"), + provider_options.at("numThreads"), provider_options.at("powerPreference")); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index ccedc71b9119a..f02d180ab104f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2025,9 +2025,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spansecond.tensor_shape.has_value()) { + const auto& opt_shape = iter->second.tensor_shape; + if (opt_shape.has_value() && !opt_shape->GetDims().empty()) { ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, input_output_tensor.Shape(), - *iter->second.tensor_shape, input_output_moniker)); + *opt_shape, input_output_moniker)); } } else if (input_output_ml_value.IsSparseTensor()) { #if !defined(DISABLE_SPARSE_TENSORS) @@ -2038,9 +2039,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spansecond.tensor_shape.has_value()) { + const auto& opt_shape = iter->second.tensor_shape; + if (opt_shape.has_value() && !opt_shape->GetDims().empty()) { ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(), - *iter->second.tensor_shape, input_output_moniker)); + *opt_shape, input_output_moniker)); } } else if (is_sparse_initializer(name) && expected_type->IsTensorType()) { @@ -2049,9 +2051,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spansecond.tensor_shape.has_value()) { + const auto& opt_shape = iter->second.tensor_shape; + if (opt_shape.has_value() && !opt_shape->GetDims().empty()) { ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(), - *iter->second.tensor_shape, input_output_moniker)); + *opt_shape, input_output_moniker)); } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name: '", name, @@ -2061,7 +2064,6 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spanIsTensorSequenceType() #if !defined(DISABLE_OPTIONAL_TYPE) diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 4649ac35c3647..cb51a0c460d9a 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -104,8 +104,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, } else if (strcmp(provider_name, "WEBNN") == 0) { #if defined(USE_WEBNN) std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu"); + std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0"); std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default"); provider_options["deviceType"] = deviceType; + provider_options["numThreads"] = numThreads; provider_options["powerPreference"] = powerPreference; options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options)); #else diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index a91ff91010e4b..a9cbef98d9165 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -157,6 +157,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "MemcpyFromHost": self._pass_on_shape_and_type, "MemcpyToHost": self._pass_on_shape_and_type, "Min": self._infer_symbolic_compute_ops, + "MoE": self._pass_on_shape_and_type, "Mul": self._infer_symbolic_compute_ops, "NonMaxSuppression": self._infer_NonMaxSuppression, "NonZero": self._infer_NonZero, diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index c1b241aa1a5ec..d11cb91d98b0c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -657,7 +657,6 @@ def create_multihead_attention_node( return None graph_input_names = set([node.name for node in self.model.graph().input]) - graph_output_names = set([node.name for node in self.model.graph().output]) mha_node_name = self.model.create_node_name("Attention") # Add initial Q/K/V inputs for MHA @@ -693,12 +692,15 @@ def create_multihead_attention_node( mha_inputs.append("") # Add optional inputs for MHA - if past_k and past_v and past_k in graph_input_names and past_v in graph_input_names: + + if past_k and past_v: mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v]) + elif key_padding_mask or add_qk: + mha_inputs.extend([key_padding_mask, add_qk]) # Add outputs for MHA mha_outputs = [output] - if present_k and present_v and present_k in graph_output_names and present_v in graph_output_names: + if present_k and present_v: mha_outputs.extend([present_k, present_v]) mha_node = helper.make_node( diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py new file mode 100644 index 0000000000000..6bc681c57444e --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -0,0 +1,143 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +from fusion_attention import AttentionMask, FusionAttention +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionConformerAttention(FusionAttention): + """ + Fuse Conformer Attention subgraph into one MultiHeadAttention node. + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ) + if qkv_nodes is not None: + ( + _, + _, + reshape_qkv, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes + else: + logger.debug("fuse_conformer_attention: failed to match qkv path") + return + + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 1, 0, 0, 1], + ) + + add_v = None + if v_nodes is not None: + (concat_v, _, _, add_v, matmul_v) = v_nodes + concat_parent = self.model.get_parent(concat_v, 0, None) + present_v = concat_v.output[0] + past_v = concat_parent.output[0] + else: + logger.debug("fuse_conformer_attention: failed to match v path") + return + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + + if qk_nodes is not None: + _, add_qk, matmul_qk = qk_nodes + else: + logger.debug("fuse_conformer_attention: failed to match qk path") + return + + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Div", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], + ) + if q_nodes is not None: + _, _, reshape_q, add_q, matmul_q = q_nodes + else: + logger.debug("fuse_conformer_attention: failed to match q path") + return + + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 1, 0, 0, 1], + ) + + matmul_k = None + if k_nodes is not None: + _, concat_k, _, _, add_k, matmul_k = k_nodes + concat_parent = self.model.get_parent(concat_k, 0, None) + past_k = concat_parent.output[0] + present_k = concat_k.output[0] + else: + logger.debug("fuse_conformer_attention: failed to match k path") + return + + attention_last_node = reshape_qkv + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) + + if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: + logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size") + return + + new_node = self.create_multihead_attention_node( + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + num_heads, + hidden_size, + attention_last_node.output[0], + add_qk=add_qk.input[1], + past_k=past_k, + past_v=past_v, + present_k=present_k, + present_v=present_v, + ) + + if new_node is None: + logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed") + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + + # When using multihead attention, keep MatMul nodes in original graph + if q_nodes[-1].op_type == "MatMul": + q_nodes.pop() + if k_nodes[-1].op_type == "MatMul": + k_nodes.pop() + if v_nodes[-1].op_type == "MatMul": + v_nodes.pop() + + self.nodes_to_remove.extend(k_nodes) + self.nodes_to_remove.extend(v_nodes) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 3b344d6dc9342..407c3b80e153f 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -157,14 +157,14 @@ def hook_for_inputs(_, inputs, kwargs): for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): if type(value) is torch.Tensor: value.to(model.device) - # Didn't touch past_key_value now, please change it if you want if "use_cache" in key: onnx_inputs[idx] = with_past + out = model(sample_inputs[0], attention_mask=sample_inputs[1], use_cache=with_past) if with_past else out return input_keys, onnx_inputs, out.past_key_values -def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: +def move_to_appropriate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: """ According to the model size, we will upload it to CPU if has no GPU or enough GPU memory, @@ -307,7 +307,7 @@ def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, wit """ model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) - model = move_to_approprate_device(model, sample_inputs_tp) + model = move_to_appropriate_device(model, sample_inputs_tp) sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 0c6f830ed26b0..44dea3cb73b6e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -135,7 +135,7 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --use_gqa ``` -Note: GroupQueryAttention currently runs on Linux for FP16 CUDA and INT4 CUDA models, and it can provide faster inference than MultiHeadAttention, especially for large sequence lengths (e.g. 1024 or larger). For the best performance, you should pre-allocate the KV cache buffers to have size `(batch_size, num_heads, max_sequence_length, head_size)` so that the past KV and present KV caches share the same memory. You also need to bind them with ONNX Runtime's [IO binding](https://onnxruntime.ai/docs/api/python/api_summary.html#iobinding). +Note: GroupQueryAttention currently works with the FP16 CUDA and INT4 CUDA models, and it can provide faster inference than MultiHeadAttention, especially for large sequence lengths (e.g. 1024 or larger). For the best performance, you should pre-allocate the KV cache buffers to have size `(batch_size, num_heads, max_sequence_length, head_size)` so that the past KV and present KV caches share the same memory. You also need to bind them with ONNX Runtime's [IO binding](https://onnxruntime.ai/docs/api/python/api_summary.html#iobinding). Here is an example of how you can bind directly to `torch.tensor` objects: ``` @@ -340,6 +340,18 @@ CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \ --device cuda ``` +9. ONNX Runtime, FP16, convert_to_onnx, LLaMA-2 70B shard to 4 GPUs +``` +CUDA_VISIBLE_DEVICES=4,5,6,7 bash benchmark_70b_model.sh 4 \ + --benchmark-type ort-convert-to-onnx \ + --ort-model-path ./llama2-70b-dis/rank_{}_Llama-2-70b-hf_decoder_merged_model_fp16.onnx \ + --model-name meta-llama/Llama-2-70b-hf \ + --precision fp16 \ + --device cuda \ + --warmup-runs 5 \ + --num-runs 100 +``` + You can profile a variant by adding the `--profile` flag and providing one batch size and sequence length combination. ### Benchmark All diff --git a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py index 50b0669d6d83a..72192ce8d8c63 100644 --- a/onnxruntime/python/tools/transformers/models/llama/dist_settings.py +++ b/onnxruntime/python/tools/transformers/models/llama/dist_settings.py @@ -2,8 +2,6 @@ import torch.distributed as dist -comm = None - def init_dist(): if "LOCAL_RANK" in os.environ: @@ -13,10 +11,6 @@ def init_dist(): dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank) elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: - from mpi4py import MPI - - comm = MPI.COMM_WORLD # noqa: F841 - int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)) rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) @@ -27,15 +21,28 @@ def init_dist(): pass +def _get_comm(): + try: + from mpi4py import MPI + + comm = MPI.COMM_WORLD + return comm + except ImportError: + return None + + def get_rank(): + comm = _get_comm() return comm.Get_rank() if comm is not None else 0 def get_size(): + comm = _get_comm() return comm.Get_size() if comm is not None else 1 def barrier(): + comm = _get_comm() if comm is not None: comm.Barrier() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index d937e3f4213e0..54af8844d0c6c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -1,76 +1,97 @@ # Stable Diffusion GPU Optimization -## Overview - -[Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement) is a text-to-image latent diffusion model for image generation. Explanation of the Stable Diffusion can be found in [Stable Diffusion with Diffusers](https://huggingface.co/blog/stable_diffusion). - -## Optimizations for Stable Diffusion - ONNX Runtime uses the following optimizations to speed up Stable Diffusion in CUDA: * [Flash Attention](https://arxiv.org/abs/2205.14135) for float16 precision. Flash Attention uses tiling to reduce number of GPU memory reads/writes, and improves performance with less memory for long sequence length. The kernel requires GPUs of Compute Capability >= 7.5 (like T4, A100, and RTX 2060~4090). * [Memory Efficient Attention](https://arxiv.org/abs/2112.05682v2) for float32 precision or older GPUs (like V100). We used the fused multi-head attention kernel in CUTLASS, and the kernel was contributed by xFormers. * Channel-last (NHWC) convolution. For NVidia GPU with Tensor Cores support, NHWC tensor layout is recommended for convolution. See [Tensor Layouts In Memory: NCHW vs NHWC](https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#tensor-layout). -* GroupNorm kernel for NHWC tensor layout. +* GroupNorm for NHWC tensor layout, and SkipGroupNorm fusion which fuses GroupNorm with Add bias and residual inputs * SkipLayerNormalization which fuses LayerNormalization with Add bias and residual inputs. * BiasSplitGelu is a fusion of Add bias with SplitGelu activation. * BiasAdd fuses Add bias and residual. * Reduce Transpose nodes by graph transformation. -These optimizations are firstly carried out on CUDA EP. They may not work on other EP. To show the impact of each optimization on latency and GPU memory, we did some experiments: +These optimizations are firstly carried out on CUDA EP. They may not work on other EP. -### Results on RTX 3060 GPU: +## Scripts: -| Optimizations | Average Latency (batch_size=1) | Memory in MB (batch_size=1) | Average Latency (batch_size=8) | Memory in MB (batch_size=8) | -| ---------------------------------------------------------------------------------- | ------------------------------ | --------------------------- | ------------------------------ | --------------------------- | -| Raw FP32 models | 25.6 | 10,667 | OOM | OOM | -| FP16 baseline | 10.2 | 10,709 | OOM | OOM | -| FP16 baseline + FMHA | 6.1 | 7,719 | 39.1 | 10,821 | -| FP16 baseline + FMHA + NhwcConv | 5.5 | 7,656 | 38.8 | 11,615 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm | 5.1 | 6,673 | 35.8 | 10,763 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu | 4.9 | 4,447 | 33.7 | 6,669 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + Packed QKV | 4.8 | 4,625 | 33.5 | 6,663 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + Packed QKV + BiasAdd | 4.7 | 4,480 | 33.3 | 6,499 | +| Script | Description | +| ---------------------------------------------- | ----------------------------------------------------------------------------------------- | +| [demo_txt2img_xl.py](./demo_txt2img_xl.py) | Demo of text to image generation using Stable Diffusion XL model. | +| [demo_txt2img.py](./demo_txt2img.py) | Demo of text to image generation using Stable Diffusion models except XL. | +| [optimize_pipeline.py](./optimize_pipeline.py) | Optimize Stable Diffusion ONNX models exported from Huggingface diffusers or optimum | +| [benchmark.py](./benchmark.py) | Benchmark latency and memory of OnnxRuntime, xFormers or PyTorch 2.0 on stable diffusion. | -FP16 baseline contains optimizations available in ONNX Runtime 1.13 including LayerNormalization, SkipLayerNormalization, Gelu and float16 conversion. -Here FMHA means Attention and MultiHeadAttention operators with Flash Attention and Memory Efficient Attention kernels but inputs are not packed. Packed QKV means the inputs are packed. +## Run demo with docker -The last two optimizations (Packed QKV and BiasAdd) are only available in nightly package. Compared to 1.14.1, nightly package has slight improvement in performance. +#### Clone the onnxruntime repository +``` +git clone https://github.com/microsoft/onnxruntime +cd onnxruntime +``` -### Results on MI250X with 1 GCD +#### Launch NVIDIA pytorch container -With runtime tuning enabled, we get following performance number on one GCD of a MI250X GPU: +Install nvidia-docker using [these instructions](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker). -| Optimizations | Average Latency (batch_size=1) | Memory in MB (batch_size=1) | Average Latency (batch_size=8) | Memory in MB (batch_size=8) | -| --------------------------------------------------------------------- | ------------------------------ | --------------------------- | ------------------------------ | --------------------------- | -| Raw FP32 models | 6.7 | 17,319 | 36.4 * | 33,787 | -| FP16 baseline | 4.1 | 8,945 | 24.0 * | 34,493 | -| FP16 baseline + FMHA | 2.6 | 4,886 | 15.0 | 10,146 | -| FP16 baseline + FMHA + NhwcConv | 2.4 | 4,952 | 14.8 | 9,632 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm | 2.3 | 4,906 | 13.6 | 9,774 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu | 2.2 | 4,910 | 12.5 | 9,646 | -| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + BiasAdd | 2.2 | 4,910 | 12.5 | 9,778 | +``` +docker run --rm -it --gpus all -v $PWD:/workspace nvcr.io/nvidia/pytorch:23.10-py3 /bin/bash +``` -The entries marked with `*` produce suspicious output images. The might be numerical stability or correctness issue for the pipeline. The performance number is for reference only. +#### Build onnxruntime from source +After launching the docker, you can build and install onnxruntime-gpu wheel like the following. +``` +export CUDACXX=/usr/local/cuda-12.2/bin/nvcc +git config --global --add safe.directory '*' +sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_version 12.2 \ + --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/lib/x86_64-linux-gnu/ --build_wheel --skip_tests \ + --use_tensorrt --tensorrt_home /usr/src/tensorrt \ + --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ + --allow_running_as_root +python3 -m pip install --upgrade pip +python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall +``` -## Scripts: +If the GPU is not A100, change `CMAKE_CUDA_ARCHITECTURES=80` in the command line according to the GPU compute capacity. -| Script | Description | -| ---------------------------------------------- | ----------------------------------------------------------------------------------------- | -| [optimize_pipeline.py](./optimize_pipeline.py) | Optimize Stable Diffusion ONNX models | -| [benchmark.py](./benchmark.py) | Benchmark latency and memory of OnnxRuntime, xFormers or PyTorch 2.0 on stable diffusion. | +#### Install required packages +``` +cd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion +python3 -m pip install -r requirements-cuda12.txt +python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +``` -In below example, we run the scripts in source code directory. You can get source code like the following: +### Run Demo +You can review the usage of supported pipelines like the following: ``` -git clone https://github.com/microsoft/onnxruntime -cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion +python3 demo_txt2img.py --help +python3 demo_txt2img_xl.py --help ``` -## Example of Stable Diffusion 1.5 +For example: +`--engine {ORT_CUDA,ORT_TRT,TRT}` can be used to choose different backend engines including CUDA or TensorRT execution provider of ONNX Runtime, or TensorRT. +`--work-dir WORK_DIR` can be used to load or save models under the given directory. You can download the [optimized ONNX models of Stable Diffusion XL 1.0](https://huggingface.co/tlwu/stable-diffusion-xl-1.0-onnxruntime#usage-example) to save time in running the XL demo. + +#### Generate an image guided by a text prompt +```python3 demo_txt2img.py "astronaut riding a horse on mars"``` + +#### Generate an image with Stable Diffusion XL guided by a text prompt +```python3 demo_txt2img_xl.py "starry night over Golden Gate Bridge by van gogh"``` -Below is an example to optimize Stable Diffusion 1.5 in Linux. For Windows OS, please change the format of path to be like `.\sd` instead of `./sd`. +If you do not provide prompt, the script will generate different image sizes for a list of prompts for demonstration. + +#### Generate an image with SDXL LCM guided by a text prompt +```python3 demo_txt2img_xl.py --lcm --disable-refiner "an astronaut riding a rainbow unicorn, cinematic, dramatic"``` + +## Optimize Stable Diffusion ONNX models for Hugging Face Diffusers or Optimum + +If you are able to run the above demo with docker, you can use the docker and skip the following setup and fast forward to [Export ONNX pipeline](#export-onnx-pipeline). + +Below setup does not use docker. We'll use the environment to optimize ONNX models of Stable Diffusion exported by huggingface diffusers or optimum. +For Windows OS, please change the format of path to be like `.\sd` instead of `./sd`. It is recommended to create a Conda environment with Python 3.10 for the following setup: ``` @@ -78,7 +99,7 @@ conda create -n py310 python=3.10 conda activate py310 ``` -### Setup Environment (CUDA) +### Setup Environment (CUDA) without docker First, we need install CUDA 11.8 or 12.1, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html) 8.5 or above, and [TensorRT 8.6.1](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) in the machine. @@ -86,12 +107,19 @@ First, we need install CUDA 11.8 or 12.1, [cuDNN](https://docs.nvidia.com/deeple In the Conda environment, install PyTorch 2.1 or above, and other required packages like the following: ``` -pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118 +pip install torch --index-url https://download.pytorch.org/whl/cu118 pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-cuda11.txt ``` -We cannot directly `pip install tensorrt` for CUDA 11. Follow https://github.com/NVIDIA/TensorRT/issues/2773 to install TensorRT for CUDA 11 in Linux. For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. +For Windows, install nvtx like the following: +``` +conda install -c conda-forge nvtx +``` + +We cannot directly `pip install tensorrt` for CUDA 11. Follow https://github.com/NVIDIA/TensorRT/issues/2773 to install TensorRT for CUDA 11 in Linux. + +For Windows, pip install the tensorrt wheel in the downloaded TensorRT zip file instead. Like `pip install tensorrt-8.6.1.6.windows10.x86_64.cuda-11.8\tensorrt-8.6.1.6\python\tensorrt-8.6.1-cp310-none-win_amd64.whl`. #### CUDA 12.*: The official package of onnxruntime-gpu 1.16.* is built for CUDA 11.8. To use CUDA 12.*, you will need [build onnxruntime from source](https://onnxruntime.ai/docs/build/inferencing.html). @@ -99,6 +127,7 @@ The official package of onnxruntime-gpu 1.16.* is built for CUDA 11.8. To use CU ``` git clone --recursive https://github.com/Microsoft/onnxruntime.git cd onnxruntime +pip install cmake pip install -r requirements-dev.txt ``` Follow [example script for A100 in Ubuntu](https://github.com/microsoft/onnxruntime/blob/26a7b63716e3125bfe35fe3663ba10d2d7322628/build_release.sh) @@ -106,7 +135,7 @@ or [example script for RTX 4090 in Windows](https://github.com/microsoft/onnxrun Then install other python packages like the following: ``` -pip install torch --index-url https://download.pytorch.org/whl/nightly/cu121 +pip install torch --index-url https://download.pytorch.org/whl/cu121 pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-cuda12.txt ``` @@ -182,7 +211,13 @@ Example to optimize the exported float32 ONNX models, and save to float16 models python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd_v1_5/fp32 -o ./sd_v1_5/fp16 --float16 ``` -For SDXL model, it is recommended to use a machine with 32 GB or more memory to optimize. +In all examples below, we run the scripts in source code directory. You can get source code like the following: +``` +git clone https://github.com/microsoft/onnxruntime +cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion +``` + +For SDXL model, it is recommended to use a machine with 48 GB or more memory to optimize. ``` python optimize_pipeline.py -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16 ``` @@ -265,6 +300,44 @@ python benchmark.py -e tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 python benchmark.py -e onnxruntime -r tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph ``` +### Results on RTX 3060 GPU: + +To show the impact of each optimization on latency and GPU memory, we did some experiments: + +| Optimizations | Average Latency (batch_size=1) | Memory in MB (batch_size=1) | Average Latency (batch_size=8) | Memory in MB (batch_size=8) | +| ---------------------------------------------------------------------------------- | ------------------------------ | --------------------------- | ------------------------------ | --------------------------- | +| Raw FP32 models | 25.6 | 10,667 | OOM | OOM | +| FP16 baseline | 10.2 | 10,709 | OOM | OOM | +| FP16 baseline + FMHA | 6.1 | 7,719 | 39.1 | 10,821 | +| FP16 baseline + FMHA + NhwcConv | 5.5 | 7,656 | 38.8 | 11,615 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm | 5.1 | 6,673 | 35.8 | 10,763 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu | 4.9 | 4,447 | 33.7 | 6,669 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + Packed QKV | 4.8 | 4,625 | 33.5 | 6,663 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + Packed QKV + BiasAdd | 4.7 | 4,480 | 33.3 | 6,499 | + +FP16 baseline contains optimizations available in ONNX Runtime 1.13 including LayerNormalization, SkipLayerNormalization, Gelu and float16 conversion. + +Here FMHA means Attention and MultiHeadAttention operators with Flash Attention and Memory Efficient Attention kernels but inputs are not packed. Packed QKV means the inputs are packed. + +The last two optimizations (Packed QKV and BiasAdd) are only available in nightly package. Compared to 1.14.1, nightly package has slight improvement in performance. + +### Results on MI250X with 1 GCD + +With runtime tuning enabled, we get following performance number on one GCD of a MI250X GPU: + +| Optimizations | Average Latency (batch_size=1) | Memory in MB (batch_size=1) | Average Latency (batch_size=8) | Memory in MB (batch_size=8) | +| --------------------------------------------------------------------- | ------------------------------ | --------------------------- | ------------------------------ | --------------------------- | +| Raw FP32 models | 6.7 | 17,319 | 36.4 * | 33,787 | +| FP16 baseline | 4.1 | 8,945 | 24.0 * | 34,493 | +| FP16 baseline + FMHA | 2.6 | 4,886 | 15.0 | 10,146 | +| FP16 baseline + FMHA + NhwcConv | 2.4 | 4,952 | 14.8 | 9,632 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm | 2.3 | 4,906 | 13.6 | 9,774 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu | 2.2 | 4,910 | 12.5 | 9,646 | +| FP16 baseline + FMHA + NhwcConv + GroupNorm + BiasSplitGelu + BiasAdd | 2.2 | 4,910 | 12.5 | 9,778 | + +The entries marked with `*` produce suspicious output images. The might be numerical stability or correctness issue for the pipeline. The performance number is for reference only. + + ### Example Benchmark output Common settings for below test results: @@ -400,7 +473,8 @@ Results are from Standard_NC4as_T4_v3 Azure virtual machine: ### Credits -Some CUDA kernels (Flash Attention, GroupNorm, SplitGelu and BiasAdd etc.) were originally implemented in [TensorRT](https://github.com/nviDIA/TensorRT) by Nvidia. +Some CUDA kernels (TensorRT Fused Attention, GroupNorm, SplitGelu and BiasAdd etc.) and demo diffusion were originally implemented in [TensorRT](https://github.com/nviDIA/TensorRT) by Nvidia. +We use [Flash Attention v2](https://github.com/Dao-AILab/flash-attention) in Linux. We use Memory efficient attention from [CUTLASS](https://github.com/NVIDIA/cutlass). The kernels were developed by Meta xFormers. The ONNX export script and pipeline for stable diffusion was developed by Huggingface [diffusers](https://github.com/huggingface/diffusers) library. @@ -408,10 +482,8 @@ Most ROCm kernel optimizations are from [composable kernel](https://github.com/R Some kernels are enabled by MIOpen. We hereby thank for the AMD developers' collaboration. ### Future Works - -There are other optimizations might improve the performance or reduce memory footprint: -* Export the whole pipeline into a single ONNX model. Currently, there are multiple ONNX models (CLIP, VAE and U-Net etc). Each model uses separated thread pool and memory allocator. Combine them into one model could share thread pool and memory allocator. The end result is more efficient and less memory footprint. -* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention in Windows to improve performance. -* Reduce GPU memory footprint by actively deleting buffers for intermediate results. -* Safety Checker Optimization -* Leverage FP8 in latest GPU +* Update demo to support inpainting, LoRA Weights and Control Net. +* Support flash attention in Windows. +* Integration with UI. +* Optimization for H100 GPU. +* Export the whole pipeline into a single ONNX model. This senario is mainly for mobile device. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index fb051ac1ed3b4..b3056cc47c647 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -22,7 +22,7 @@ import coloredlogs from cuda import cudart -from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from demo_utils import get_metadata, init_pipeline, parse_arguments, repeat_prompt from diffusion_models import PipelineInfo from engine_builder import EngineType, get_engine_type from pipeline_txt2img import Txt2ImgPipeline @@ -53,8 +53,31 @@ f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" ) - pipeline_info = PipelineInfo(args.version) - pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size) + # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size. + # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance. + # This range can cover common used shape of landscape 512x768, portrait 768x512, or square 512x512 and 768x768. + min_image_size = 512 if args.engine != "ORT_CUDA" else 256 + max_image_size = 768 if args.engine != "ORT_CUDA" else 1024 + pipeline_info = PipelineInfo(args.version, min_image_size=min_image_size, max_image_size=max_image_size) + + # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to + # optimize the shape used most frequently. We can let user config it when we develop a UI plugin. + # In this demo, we optimize batch size 1 and image size 512x512 (or 768x768 for SD 2.0/2.1) for dynamic engine. + # This is mainly for benchmark purpose to simulate the case that we have no knowledge of user's preference. + opt_batch_size = 1 if args.build_dynamic_batch else batch_size + opt_image_height = pipeline_info.default_image_size() if args.build_dynamic_shape else args.height + opt_image_width = pipeline_info.default_image_size() if args.build_dynamic_shape else args.width + + pipeline = init_pipeline( + Txt2ImgPipeline, + pipeline_info, + engine_type, + args, + max_batch_size, + opt_batch_size, + opt_image_height, + opt_image_width, + ) if engine_type == EngineType.TRT: max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) @@ -81,17 +104,25 @@ def run_inference(warmup=False): if not args.disable_cuda_graph: # inference once to get cuda graph - _image, _latency = run_inference(warmup=True) + _, _ = run_inference(warmup=True) print("[I] Warming up ..") for _ in range(args.num_warmup_runs): - _image, _latency = run_inference(warmup=True) + _, _ = run_inference(warmup=True) print("[I] Running StableDiffusion pipeline") if args.nvtx_profile: cudart.cudaProfilerStart() - _image, _latency = run_inference(warmup=False) + images, perf_data = run_inference(warmup=False) if args.nvtx_profile: cudart.cudaProfilerStop() + metadata = get_metadata(args, False) + metadata.update(pipeline.metadata()) + if perf_data: + metadata.update(perf_data) + metadata["images"] = len(images) + print(metadata) + pipeline.save_images(images, prompt, negative_prompt, metadata) + pipeline.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index 0b529875a2fe7..7ff1794a68f8c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -22,7 +22,7 @@ import coloredlogs from cuda import cudart -from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from demo_utils import get_metadata, init_pipeline, parse_arguments, repeat_prompt from diffusion_models import PipelineInfo from engine_builder import EngineType, get_engine_type from pipeline_img2img_xl import Img2ImgXLPipeline @@ -46,18 +46,62 @@ def load_pipelines(args, batch_size): if batch_size > max_batch_size: raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.") - # No VAE decoder in base when it outputs latent instead of image. - base_info = PipelineInfo(args.version, use_vae=False, min_image_size=640, max_image_size=1536) - base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size) + # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size. + # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance. + # This range can cover most frequent shape of landscape (832x1216), portrait (1216x832) or square (1024x1024). + min_image_size = 832 if args.engine != "ORT_CUDA" else 512 + max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048 - refiner_info = PipelineInfo(args.version, is_refiner=True, min_image_size=640, max_image_size=1536) - refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size) + # No VAE decoder in base when it outputs latent instead of image. + base_info = PipelineInfo( + args.version, + use_vae=args.disable_refiner, + min_image_size=min_image_size, + max_image_size=max_image_size, + use_lcm=args.lcm, + ) + + # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to + # optimize the shape used most frequently. We can let user config it when we develop a UI plugin. + # In this demo, we optimize batch size 1 and image size 1024x1024 for SD XL dynamic engine. + # This is mainly for benchmark purpose to simulate the case that we have no knowledge of user's preference. + opt_batch_size = 1 if args.build_dynamic_batch else batch_size + opt_image_height = base_info.default_image_size() if args.build_dynamic_shape else args.height + opt_image_width = base_info.default_image_size() if args.build_dynamic_shape else args.width + + base = init_pipeline( + Txt2ImgXLPipeline, + base_info, + engine_type, + args, + max_batch_size, + opt_batch_size, + opt_image_height, + opt_image_width, + ) + + refiner = None + if not args.disable_refiner: + refiner_info = PipelineInfo( + args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size + ) + refiner = init_pipeline( + Img2ImgXLPipeline, + refiner_info, + engine_type, + args, + max_batch_size, + opt_batch_size, + opt_image_height, + opt_image_width, + ) if engine_type == EngineType.TRT: - max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) + max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory()) _, shared_device_memory = cudart.cudaMalloc(max_device_memory) base.backend.activate_engines(shared_device_memory) - refiner.backend.activate_engines(shared_device_memory) + if refiner: + refiner.backend.activate_engines(shared_device_memory) if engine_type == EngineType.ORT_CUDA: enable_vae_slicing = args.enable_vae_slicing @@ -65,7 +109,7 @@ def load_pipelines(args, batch_size): print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.") enable_vae_slicing = True if enable_vae_slicing: - refiner.backend.enable_vae_slicing() + (refiner or base).backend.enable_vae_slicing() return base, refiner @@ -74,10 +118,11 @@ def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False image_width = args.width batch_size = len(prompt) base.load_resources(image_height, image_width, batch_size) - refiner.load_resources(image_height, image_width, batch_size) + if refiner: + refiner.load_resources(image_height, image_width, batch_size) def run_base_and_refiner(warmup=False): - images, time_base = base.run( + images, base_perf = base.run( prompt, negative_prompt, image_height, @@ -86,22 +131,34 @@ def run_base_and_refiner(warmup=False): denoising_steps=args.denoising_steps, guidance=args.guidance, seed=args.seed, - return_type="latent", + return_type="latent" if refiner else "image", ) + if refiner is None: + return images, base_perf - images, time_refiner = refiner.run( + # Use same seed in base and refiner. + seed = base.get_current_seed() + + images, refiner_perf = refiner.run( prompt, negative_prompt, images, image_height, image_width, warmup=warmup, - denoising_steps=args.denoising_steps, - guidance=args.guidance, - seed=args.seed, + denoising_steps=args.refiner_steps, + strength=args.strength, + guidance=args.refiner_guidance, + seed=seed, ) - return images, time_base + time_refiner + perf_data = None + if base_perf and refiner_perf: + perf_data = {"latency": base_perf["latency"] + refiner_perf["latency"]} + perf_data.update({"base." + key: val for key, val in base_perf.items()}) + perf_data.update({"refiner." + key: val for key, val in refiner_perf.items()}) + + return images, perf_data if not args.disable_cuda_graph: # inference once to get cuda graph @@ -118,13 +175,24 @@ def run_base_and_refiner(warmup=False): print("[I] Running StableDiffusion XL pipeline") if args.nvtx_profile: cudart.cudaProfilerStart() - _, latency = run_base_and_refiner(warmup=False) + images, perf_data = run_base_and_refiner(warmup=False) if args.nvtx_profile: cudart.cudaProfilerStop() - print("|------------|--------------|") - print("| {:^10} | {:>9.2f} ms |".format("e2e", latency)) - print("|------------|--------------|") + if refiner: + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("e2e", perf_data["latency"])) + print("|------------|--------------|") + + metadata = get_metadata(args, True) + metadata.update({"base." + key: val for key, val in base.metadata().items()}) + if refiner: + metadata.update({"refiner." + key: val for key, val in refiner.metadata().items()}) + if perf_data: + metadata.update(perf_data) + metadata["images"] = len(images) + print(metadata) + (refiner or base).save_images(images, prompt, negative_prompt, metadata) def run_demo(args): @@ -135,59 +203,92 @@ def run_demo(args): base, refiner = load_pipelines(args, batch_size) run_pipelines(args, base, refiner, prompt, negative_prompt) base.teardown() - refiner.teardown() + if refiner: + refiner.teardown() def run_dynamic_shape_demo(args): - """Run demo of generating images with different size with list of prompts with ORT CUDA provider.""" + """Run demo of generating images with different settings with ORT CUDA provider.""" args.engine = "ORT_CUDA" - args.scheduler = "UniPC" - args.denoising_steps = 8 args.disable_cuda_graph = True + if args.lcm: + args.disable_refiner = True + base, refiner = load_pipelines(args, 1) - batch_size = args.repeat_prompt - base, refiner = load_pipelines(args, batch_size) + prompts = [ + "starry night over Golden Gate Bridge by van gogh", + "beautiful photograph of Mt. Fuji during cherry blossom", + "little cute gremlin sitting on a bed, cinematic", + "cute grey cat with blue eyes, wearing a bowtie, acrylic painting", + "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation", + "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic", + "An astronaut riding a rainbow unicorn, cinematic, dramatic", + "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm", + ] - image_sizes = [ - (1024, 1024), - (1152, 896), - (896, 1152), - (1216, 832), - (832, 1216), - (1344, 768), - (768, 1344), - (1536, 640), - (640, 1536), + # refiner, batch size, height, width, scheduler, steps, prompt, seed, guidance, refiner scheduler, refiner steps, refiner strength + configs = [ + (1, 832, 1216, "UniPC", 8, prompts[0], None, 5.0, "UniPC", 10, 0.3), + (1, 1024, 1024, "DDIM", 24, prompts[1], None, 5.0, "DDIM", 30, 0.3), + (1, 1216, 832, "UniPC", 16, prompts[2], None, 5.0, "UniPC", 10, 0.3), + (1, 1344, 768, "DDIM", 24, prompts[3], None, 5.0, "UniPC", 20, 0.3), + (2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712, 5.0, "UniPC", 10, 0.3), + (2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906, 5.0, "UniPC", 20, 0.3), ] - # Warm up the pipelines. This only need once before serving. + # In testing LCM, refiner is disabled so the settings of refiner is not used. + if args.lcm: + configs = [ + (1, 1024, 1024, "LCM", 8, prompts[6], None, 1.0, "UniPC", 20, 0.3), + (1, 1216, 832, "LCM", 6, prompts[7], 1337, 1.0, "UniPC", 20, 0.3), + ] + + # Warm up each combination of (batch size, height, width) once before serving. args.prompt = ["warm up"] - args.num_warmup_runs = 3 - prompt, negative_prompt = repeat_prompt(args) - for height, width in image_sizes: + args.num_warmup_runs = 1 + for batch_size, height, width, _, _, _, _, _, _, _, _ in configs: + args.batch_size = batch_size args.height = height args.width = width - print(f"\nWarm up pipelines for Batch_size={batch_size}, Height={height}, Width={width}") + print(f"\nWarm up batch_size={batch_size}, height={height}, width={width}") + prompt, negative_prompt = repeat_prompt(args) run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=True) # Run pipeline on a list of prompts. - prompts = [ - "starry night over Golden Gate Bridge by van gogh", - "little cute gremlin sitting on a bed, cinematic", - ] args.num_warmup_runs = 0 - for example_prompt in prompts: + for ( + batch_size, + height, + width, + scheduler, + steps, + example_prompt, + seed, + guidance, + refiner_scheduler, + refiner_steps, + strength, + ) in configs: args.prompt = [example_prompt] + args.batch_size = batch_size + args.height = height + args.width = width + args.scheduler = scheduler + args.denoising_steps = steps + args.seed = seed + args.guidance = guidance + args.refiner_scheduler = refiner_scheduler + args.refiner_steps = refiner_steps + args.strength = strength + base.set_scheduler(scheduler) + if refiner: + refiner.set_scheduler(refiner_scheduler) prompt, negative_prompt = repeat_prompt(args) - - for height, width in image_sizes: - args.height = height - args.width = width - print(f"\nBatch_size={batch_size}, Height={height}, Width={width}, Prompt={example_prompt}") - run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False) + run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False) base.teardown() - refiner.teardown() + if refiner: + refiner.teardown() if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index d0e4e3adefbc3..70b4f34fdd988 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -21,6 +21,7 @@ # -------------------------------------------------------------------------- import argparse +from typing import Any, Dict import torch from diffusion_models import PipelineInfo @@ -68,8 +69,8 @@ def parse_arguments(is_xl: bool, description: str): "--scheduler", type=str, default="DDIM", - choices=["DDIM", "UniPC"] if is_xl else ["DDIM", "EulerA", "UniPC"], - help="Scheduler for diffusion process", + choices=["DDIM", "UniPC", "LCM"] if is_xl else ["DDIM", "EulerA", "UniPC"], + help="Scheduler for diffusion process" + " of base" if is_xl else "", ) parser.add_argument( @@ -84,7 +85,7 @@ def parse_arguments(is_xl: bool, description: str): "--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation." ) parser.add_argument( - "--repeat-prompt", + "--batch-size", type=int, default=1, choices=[1, 2, 4, 8, 16], @@ -105,6 +106,42 @@ def parse_arguments(is_xl: bool, description: str): help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.", ) + if is_xl: + parser.add_argument( + "--lcm", + action="store_true", + help="Use fine-tuned latent consistency model to replace the UNet in base.", + ) + + parser.add_argument( + "--refiner-scheduler", + type=str, + default="DDIM", + choices=["DDIM", "UniPC"], + help="Scheduler for diffusion process of refiner.", + ) + + parser.add_argument( + "--refiner-guidance", + type=float, + default=5.0, + help="Guidance scale used in refiner.", + ) + + parser.add_argument( + "--refiner-steps", + type=int, + default=30, + help="Number of denoising steps in refiner. Note that actual refiner steps is refiner_steps * strength.", + ) + + parser.add_argument( + "--strength", + type=float, + default=0.3, + help="A value between 0 and 1. The higher the value less the final image similar to the seed image.", + ) + # ONNX export parser.add_argument( "--onnx-opset", @@ -145,6 +182,10 @@ def parse_arguments(is_xl: bool, description: str): parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + parser.add_argument( + "--disable-refiner", action="store_true", help="Disable refiner and only run base for XL pipeline." + ) + group = parser.add_argument_group("Options for ORT_CUDA engine only") group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.") @@ -174,9 +215,9 @@ def parse_arguments(is_xl: bool, description: str): ) # Validate image dimensions - if args.height % 8 != 0 or args.width % 8 != 0: + if args.height % 64 != 0 or args.width % 64 != 0: raise ValueError( - f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}." + f"Image height and width have to be divisible by 64 but specified as: {args.height} and {args.width}." ) if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph: @@ -186,15 +227,56 @@ def parse_arguments(is_xl: bool, description: str): if args.onnx_opset is None: args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17 + if is_xl: + if args.lcm: + if args.guidance > 1.0: + print("[I] Use --guidance=1.0 for base since LCM is used.") + args.guidance = 1.0 + if args.scheduler != "LCM": + print("[I] Use --scheduler=LCM for base since LCM is used.") + args.scheduler = "LCM" + if args.denoising_steps > 16: + print("[I] Use --denoising_steps=8 (no more than 16) for base since LCM is used.") + args.denoising_steps = 8 + assert args.strength > 0.0 and args.strength < 1.0 + print(args) return args +def get_metadata(args, is_xl: bool = False) -> Dict[str, Any]: + metadata = { + "args.prompt": args.prompt, + "args.negative_prompt": args.negative_prompt, + "args.batch_size": args.batch_size, + "height": args.height, + "width": args.width, + "cuda_graph": not args.disable_cuda_graph, + "vae_slicing": args.enable_vae_slicing, + "engine": args.engine, + } + + if is_xl and not args.disable_refiner: + metadata["base.scheduler"] = args.scheduler + metadata["base.denoising_steps"] = args.denoising_steps + metadata["base.guidance"] = args.guidance + metadata["refiner.strength"] = args.strength + metadata["refiner.scheduler"] = args.refiner_scheduler + metadata["refiner.denoising_steps"] = args.refiner_steps + metadata["refiner.guidance"] = args.refiner_guidance + else: + metadata["scheduler"] = args.scheduler + metadata["denoising_steps"] = args.denoising_steps + metadata["guidance"] = args.guidance + + return metadata + + def repeat_prompt(args): if not isinstance(args.prompt, list): raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}") - prompt = args.prompt * args.repeat_prompt + prompt = args.prompt * args.batch_size if not isinstance(args.negative_prompt, list): raise ValueError( @@ -209,7 +291,9 @@ def repeat_prompt(args): return prompt, negative_prompt -def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size): +def init_pipeline( + pipeline_class, pipeline_info, engine_type, args, max_batch_size, opt_batch_size, opt_image_height, opt_image_width +): onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( work_dir=args.work_dir, pipeline_info=pipeline_info, engine_type=engine_type ) @@ -217,7 +301,7 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si # Initialize demo pipeline = pipeline_class( pipeline_info, - scheduler=args.scheduler, + scheduler=args.refiner_scheduler if pipeline_info.is_xl_refiner() else args.scheduler, output_dir=output_dir, hf_token=args.hf_token, verbose=False, @@ -234,9 +318,6 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si engine_dir=engine_dir, framework_model_dir=framework_model_dir, onnx_dir=onnx_dir, - opt_image_height=args.height, - opt_image_width=args.height, - opt_batch_size=batch_size, force_engine_rebuild=args.force_engine_build, device_id=torch.cuda.current_device(), ) @@ -247,14 +328,15 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si framework_model_dir, onnx_dir, args.onnx_opset, - opt_image_height=args.height, - opt_image_width=args.height, - opt_batch_size=batch_size, + opt_image_height=opt_image_height, + opt_image_width=opt_image_width, + opt_batch_size=opt_batch_size, force_engine_rebuild=args.force_engine_build, static_batch=not args.build_dynamic_batch, static_image_shape=not args.build_dynamic_shape, max_workspace_size=0, device_id=torch.cuda.current_device(), + timing_cache=timing_cache, ) elif engine_type == EngineType.TRT: # Load TensorRT engines and pytorch modules @@ -263,9 +345,9 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si framework_model_dir, onnx_dir, args.onnx_opset, - opt_batch_size=batch_size, - opt_image_height=args.height, - opt_image_width=args.height, + opt_batch_size=opt_batch_size, + opt_image_height=opt_image_height, + opt_image_width=opt_image_width, force_export=args.force_onnx_export, force_optimize=args.force_onnx_optimize, force_build=args.force_engine_build, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index 3dcde7f6db398..8206bee753859 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -90,6 +90,8 @@ def __init__( use_vae=False, min_image_size=256, max_image_size=1024, + use_fp16_vae=True, + use_lcm=False, ): self.version = version self._is_inpaint = is_inpaint @@ -97,7 +99,10 @@ def __init__( self._use_vae = use_vae self._min_image_size = min_image_size self._max_image_size = max_image_size + self._use_fp16_vae = use_fp16_vae + self._use_lcm = use_lcm if is_refiner: + assert not use_lcm assert self.is_xl() def is_inpaint(self) -> bool: @@ -127,6 +132,16 @@ def stages(self) -> List[str]: def vae_scaling_factor(self) -> float: return 0.13025 if self.is_xl() else 0.18215 + def vae_torch_fallback(self) -> bool: + return self.is_xl() and not self._use_fp16_vae + + def custom_fp16_vae(self) -> Optional[str]: + # For SD XL, use a VAE that fine-tuned to run in fp16 precision without generating NaNs + return "madebyollin/sdxl-vae-fp16-fix" if self._use_fp16_vae and self.is_xl() else None + + def custom_unet(self) -> Optional[str]: + return "latent-consistency/lcm-sdxl" if self._use_lcm and self.is_xl_base() else None + @staticmethod def supported_versions(is_xl: bool): return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] @@ -202,6 +217,13 @@ def min_image_size(self): def max_image_size(self): return self._max_image_size + def default_image_size(self): + if self.is_xl(): + return 1024 + if self.version in ("2.0", "2.1"): + return 768 + return 512 + class BaseModel: def __init__( @@ -714,8 +736,22 @@ def __init__( self.unet_dim = unet_dim self.time_dim = time_dim + self.custom_unet = pipeline_info.custom_unet() + self.do_classifier_free_guidance = not (self.custom_unet and "lcm" in self.custom_unet) + self.batch_multiplier = 2 if self.do_classifier_free_guidance else 1 + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + + if self.custom_unet: + model_dir = os.path.join(framework_model_dir, self.custom_unet, subfolder) + if not os.path.exists(model_dir): + unet = UNet2DConditionModel.from_pretrained(self.custom_unet, **options) + unet.save_pretrained(model_dir) + else: + unet = UNet2DConditionModel.from_pretrained(model_dir, **options) + return unet.to(self.device) + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) def get_input_names(self): @@ -725,12 +761,20 @@ def get_output_names(self): return ["latent"] def get_dynamic_axes(self): + if self.do_classifier_free_guidance: + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + "text_embeds": {0: "2B"}, + "time_ids": {0: "2B"}, + } return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - "text_embeds": {0: "2B"}, - "time_ids": {0: "2B"}, + "sample": {0: "B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "B"}, + "latent": {0: "B", 2: "H", 3: "W"}, + "text_embeds": {0: "B"}, + "time_ids": {0: "B"}, } def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): @@ -747,49 +791,52 @@ def get_input_profile(self, batch_size, image_height, image_width, static_batch, min_latent_width, max_latent_width, ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + m = self.batch_multiplier return { "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + (m * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (m * batch_size, self.unet_dim, latent_height, latent_width), + (m * max_batch, self.unet_dim, max_latent_height, max_latent_width), ], "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), + (m * min_batch, self.text_maxlen, self.embedding_dim), + (m * batch_size, self.text_maxlen, self.embedding_dim), + (m * max_batch, self.text_maxlen, self.embedding_dim), ], - "text_embeds": [(2 * min_batch, 1280), (2 * batch_size, 1280), (2 * max_batch, 1280)], + "text_embeds": [(m * min_batch, 1280), (m * batch_size, 1280), (m * max_batch, 1280)], "time_ids": [ - (2 * min_batch, self.time_dim), - (2 * batch_size, self.time_dim), - (2 * max_batch, self.time_dim), + (m * min_batch, self.time_dim), + (m * batch_size, self.time_dim), + (m * max_batch, self.time_dim), ], } def get_shape_dict(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + m = self.batch_multiplier return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "sample": (m * batch_size, self.unet_dim, latent_height, latent_width), "timestep": (1,), - "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (2 * batch_size, 4, latent_height, latent_width), - "text_embeds": (2 * batch_size, 1280), - "time_ids": (2 * batch_size, self.time_dim), + "encoder_hidden_states": (m * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (m * batch_size, 4, latent_height, latent_width), + "text_embeds": (m * batch_size, 1280), + "time_ids": (m * batch_size, self.time_dim), } def get_sample_input(self, batch_size, image_height, image_width): latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) dtype = torch.float16 if self.fp16 else torch.float32 + m = self.batch_multiplier return ( torch.randn( - 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + m * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device ), torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), { "added_cond_kwargs": { - "text_embeds": torch.randn(2 * batch_size, 1280, dtype=dtype, device=self.device), - "time_ids": torch.randn(2 * batch_size, self.time_dim, dtype=dtype, device=self.device), + "text_embeds": torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device), + "time_ids": torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device), } }, ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py index ec3041e134e75..6932c8056cf78 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -44,7 +44,6 @@ def __init__( alphas = 1.0 - betas self.alphas_cumprod = torch.cumprod(alphas, dim=0) - # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -71,7 +70,7 @@ def configure(self): self.variance = torch.from_numpy(variance).to(self.device) timesteps = self.timesteps.long().cpu() - self.alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device) + self.filtered_alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device) self.final_alpha_cumprod = self.final_alpha_cumprod.to(self.device) def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor: @@ -124,9 +123,9 @@ def step( # - pred_prev_sample -> "x_t-1" prev_idx = idx + 1 - alpha_prod_t = self.alphas_cumprod[idx] + alpha_prod_t = self.filtered_alphas_cumprod[idx] alpha_prod_t_prev = ( - self.alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod + self.filtered_alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod ) beta_prod_t = 1 - alpha_prod_t @@ -179,15 +178,15 @@ def step( variance_noise = torch.randn( model_output.shape, generator=generator, device=device, dtype=model_output.dtype ) - variance = variance ** (0.5) * eta * variance_noise + variance = std_dev_t * variance_noise prev_sample = prev_sample + variance return prev_sample def add_noise(self, init_latents, noise, idx, latent_timestep): - sqrt_alpha_prod = self.alphas_cumprod[idx] ** 0.5 - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[idx]) ** 0.5 + sqrt_alpha_prod = self.filtered_alphas_cumprod[idx] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.filtered_alphas_cumprod[idx]) ** 0.5 noisy_latents = sqrt_alpha_prod * init_latents + sqrt_one_minus_alpha_prod * noise return noisy_latents @@ -720,3 +719,228 @@ def configure(self): def __len__(self): return self.num_train_timesteps + + +# Modified from diffusers.schedulers.LCMScheduler +class LCMScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + original_inference_steps: int = 50, + clip_sample: bool = False, + clip_sample_range: float = 1.0, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + timestep_scaling: float = 10.0, + ): + self.device = device + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.final_alpha_cumprod = self.alphas_cumprod[0] + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + self.num_train_timesteps = num_train_timesteps + self.clip_sample = clip_sample + self.clip_sample_range = clip_sample_range + self.steps_offset = steps_offset + self.prediction_type = prediction_type + self.thresholding = thresholding + self.timestep_spacing = timestep_spacing + self.timestep_scaling = timestep_scaling + self.original_inference_steps = original_inference_steps + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.sample_max_value = sample_max_value + + self._step_index = None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + if len(index_candidates) > 1: + step_index = index_candidates[1] + else: + step_index = index_candidates[0] + + self._step_index = step_index.item() + + @property + def step_index(self): + return self._step_index + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps( + self, + num_inference_steps: int, + strength: int = 1.0, + ): + assert num_inference_steps <= self.num_train_timesteps + + self.num_inference_steps = num_inference_steps + original_steps = self.original_inference_steps + + assert original_steps <= self.num_train_timesteps + assert num_inference_steps <= original_steps + + # LCM Timesteps Setting + # Currently, only linear spacing is supported. + c = self.num_train_timesteps // original_steps + # LCM Training Steps Schedule + lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1 + skipping_step = len(lcm_origin_timesteps) // num_inference_steps + # LCM Inference Steps Schedule + timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] + + self.timesteps = torch.from_numpy(timesteps.copy()).to(device=self.device, dtype=torch.long) + + self._step_index = None + + def get_scalings_for_boundary_condition_discrete(self, timestep): + self.sigma_data = 0.5 # Default: 0.5 + scaled_timestep = timestep * self.timestep_scaling + + c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5 + return c_skip, c_out + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # 1. get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. Compute the predicted original sample x_0 based on the model parameterization + if self.prediction_type == "epsilon": # noise-prediction + predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + elif self.prediction_type == "sample": # x-prediction + predicted_original_sample = model_output + elif self.prediction_type == "v_prediction": # v-prediction + predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `LCMScheduler`." + ) + + # 5. Clip or threshold "predicted x_0" + if self.thresholding: + predicted_original_sample = self._threshold_sample(predicted_original_sample) + elif self.clip_sample: + predicted_original_sample = predicted_original_sample.clamp(-self.clip_sample_range, self.clip_sample_range) + + # 6. Denoise model output using boundary conditions + denoised = c_out * predicted_original_sample + c_skip * sample + + # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. + if self.step_index != self.num_inference_steps - 1: + noise = torch.randn( + model_output.shape, device=model_output.device, dtype=denoised.dtype, generator=generator + ) + prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise + else: + prev_sample = denoised + + # upon completion increase step index by one + self._step_index += 1 + + return (prev_sample,) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def configure(self): + pass + + def __len__(self): + return self.num_train_timesteps diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index dfdfa007d74eb..fac72be346b3d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -60,15 +60,8 @@ def __init__( self.torch_device = torch.device(device, torch.cuda.current_device()) self.stages = pipeline_info.stages() - # TODO: use custom fp16 for ORT_TRT, and no need to fallback to torch. - self.vae_torch_fallback = self.pipeline_info.is_xl() and engine_type != EngineType.ORT_CUDA - - # For SD XL, use an VAE that modified to run in fp16 precision without generating NaNs. - self.custom_fp16_vae = ( - "madebyollin/sdxl-vae-fp16-fix" - if self.pipeline_info.is_xl() and self.engine_type == EngineType.ORT_CUDA - else None - ) + self.vae_torch_fallback = self.pipeline_info.vae_torch_fallback() + self.custom_fp16_vae = self.pipeline_info.custom_fp16_vae() self.models = {} self.engines = {} @@ -84,6 +77,12 @@ def teardown(self): self.engines = {} def get_cached_model_name(self, model_name): + # TODO(tianleiwu): save custom model to a directory named by its original model. + if model_name == "unetxl" and self.pipeline_info.custom_unet(): + model_name = "lcm_" + model_name + + # TODO: When we support original VAE, we shall save custom VAE to another directory. + if self.pipeline_info.is_inpaint(): model_name += "_inpaint" return model_name @@ -100,6 +99,7 @@ def get_engine_path(self, engine_dir, model_name, profile_id): def load_models(self, framework_model_dir: str): # Disable torch SDPA since torch 2.0.* cannot export it to ONNX + # TODO(tianleiwu): Test and remove it if this is not needed in Torch 2.1. if hasattr(torch.nn.functional, "scaled_dot_product_attention"): delattr(torch.nn.functional, "scaled_dot_product_attention") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index 07c675b2ed990..a03ca7ce2912c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -159,9 +159,6 @@ def build_engines( framework_model_dir: str, onnx_dir: str, onnx_opset_version: int = 17, - opt_image_height: int = 512, - opt_image_width: int = 512, - opt_batch_size: int = 1, force_engine_rebuild: bool = False, device_id: int = 0, save_fp32_intermediate_model=False, @@ -209,7 +206,8 @@ def build_engines( with torch.inference_mode(): # For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern. - inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + # Export model with sample of batch size 1, image size 512 x 512 + inputs = model_obj.get_sample_input(1, 512, 512) torch.onnx.export( model, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py index 8a39dc2ed63fc..d966833aba394 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -13,6 +13,7 @@ from diffusion_models import PipelineInfo from engine_builder import EngineBuilder, EngineType from ort_utils import CudaSession +from packaging import version import onnxruntime as ort @@ -20,7 +21,17 @@ class OrtTensorrtEngine(CudaSession): - def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): + def __init__( + self, + engine_path, + device_id, + onnx_path, + fp16, + input_profile, + workspace_size, + enable_cuda_graph, + timing_cache_path=None, + ): self.engine_path = engine_path self.ort_trt_provider_options = self.get_tensorrt_provider_options( input_profile, @@ -28,6 +39,7 @@ def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, works fp16, device_id, enable_cuda_graph, + timing_cache_path=timing_cache_path, ) session_options = ort.SessionOptions() @@ -45,7 +57,9 @@ def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, works device = torch.device("cuda", device_id) super().__init__(ort_session, device, enable_cuda_graph) - def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): + def get_tensorrt_provider_options( + self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph, timing_cache_path=None + ): trt_ep_options = { "device_id": device_id, "trt_fp16_enable": fp16, @@ -55,6 +69,9 @@ def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, dev "trt_engine_cache_path": self.engine_path, } + if version.parse(ort.__version__) > version.parse("1.16.2") and timing_cache_path is not None: + trt_ep_options["trt_timing_cache_path"] = timing_cache_path + if enable_cuda_graph: trt_ep_options["trt_cuda_graph_enable"] = True @@ -153,6 +170,7 @@ def build_engines( static_image_shape=True, max_workspace_size=0, device_id=0, + timing_cache=None, ): self.torch_device = torch.device("cuda", device_id) self.load_models(framework_model_dir) @@ -224,7 +242,6 @@ def build_engines( engine_path = self.get_engine_path(engine_dir, model_name, profile_id) onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) - if not self.has_engine_file(engine_path): logger.info( "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", @@ -251,6 +268,7 @@ def build_engines( input_profile=input_profile, workspace_size=self.get_work_space_size(model_name, max_workspace_size), enable_cuda_graph=self.use_cuda_graph, + timing_cache_path=timing_cache, ) built_engines[model_name] = engine diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 4b48396b6c783..28e79abb9f018 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -91,7 +91,10 @@ def optimize( if keep_outputs: m.prune_graph(outputs=keep_outputs) - use_external_data_format = m.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF + model_size = m.model.ByteSize() + + # model size might be negative (overflow?) in Windows. + use_external_data_format = model_size <= 0 or model_size >= onnx.checker.MAXIMUM_PROTOBUF # Note that ORT < 1.16 could not save model larger than 2GB. # This step is is optional since it has no impact on inference latency. diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py index faa3f8bfaabf1..31ede1ba901f2 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py @@ -68,6 +68,7 @@ def _infer( image_height, image_width, denoising_steps=30, + strength=0.3, guidance=5.0, seed=None, warmup=False, @@ -79,7 +80,6 @@ def _infer( crops_coords_top_left = (0, 0) target_size = (image_height, image_width) - strength = 0.3 aesthetic_score = 6.0 negative_aesthetic_score = 2.5 @@ -155,12 +155,12 @@ def _infer( torch.cuda.synchronize() e2e_toc = time.perf_counter() + perf_data = None if not warmup: print("SD-XL Refiner Pipeline") - self.print_summary(e2e_tic, e2e_toc, batch_size) - self.save_images(images, "img2img-xl", prompt) + perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - return images, (e2e_toc - e2e_tic) * 1000.0 + return images, perf_data def run( self, @@ -171,6 +171,7 @@ def run( image_width, denoising_steps=30, guidance=5.0, + strength=0.3, seed=None, warmup=False, return_type="image", @@ -213,6 +214,7 @@ def run( image_height, image_width, denoising_steps=denoising_steps, + strength=strength, guidance=guidance, seed=seed, warmup=warmup, @@ -226,6 +228,7 @@ def run( image_height, image_width, denoising_steps=denoising_steps, + strength=strength, guidance=guidance, seed=seed, warmup=warmup, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py index e34fab1218b21..a0b3c3a1c85b1 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -23,12 +23,13 @@ import os import pathlib import random +from typing import Any, Dict, List import nvtx import torch from cuda import cudart from diffusion_models import PipelineInfo, get_tokenizer -from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, UniPCMultistepScheduler +from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, LCMScheduler, UniPCMultistepScheduler from engine_builder import EngineType from engine_builder_ort_cuda import OrtCudaEngineBuilder from engine_builder_ort_trt import OrtTensorrtEngineBuilder @@ -63,7 +64,7 @@ def __init__( max_batch_size (int): Maximum batch size for dynamic batch engine. scheduler (str): - The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC]. + The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC, LCM]. device (str): PyTorch device to run inference. Default: 'cuda' output_dir (str): @@ -102,35 +103,19 @@ def __init__( self.verbose = verbose self.nvtx_profile = nvtx_profile - # Scheduler options - sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012} - if self.version in ("2.0", "2.1"): - sched_opts["prediction_type"] = "v_prediction" - else: - sched_opts["prediction_type"] = "epsilon" - - if scheduler == "DDIM": - self.scheduler = DDIMScheduler(device=self.device, **sched_opts) - elif scheduler == "EulerA": - self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts) - elif scheduler == "UniPC": - self.scheduler = UniPCMultistepScheduler(device=self.device) - else: - raise ValueError("Scheduler should be either DDIM, EulerA or UniPC") - self.stages = pipeline_info.stages() - self.vae_torch_fallback = self.pipeline_info.is_xl() - self.use_cuda_graph = use_cuda_graph self.tokenizer = None self.tokenizer2 = None - self.generator = None - self.denoising_steps = None + self.generator = torch.Generator(device="cuda") self.actual_steps = None + self.current_scheduler = None + self.set_scheduler(scheduler) + # backend engine self.engine_type = engine_type if engine_type == EngineType.TRT: @@ -162,10 +147,36 @@ def __init__( def is_backend_tensorrt(self): return self.engine_type == EngineType.TRT + def set_scheduler(self, scheduler: str): + if scheduler == self.current_scheduler: + return + + # Scheduler options + sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012} + if self.version in ("2.0", "2.1"): + sched_opts["prediction_type"] = "v_prediction" + else: + sched_opts["prediction_type"] = "epsilon" + + if scheduler == "DDIM": + self.scheduler = DDIMScheduler(device=self.device, **sched_opts) + elif scheduler == "EulerA": + self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts) + elif scheduler == "UniPC": + self.scheduler = UniPCMultistepScheduler(device=self.device, **sched_opts) + elif scheduler == "LCM": + self.scheduler = LCMScheduler(device=self.device, **sched_opts) + else: + raise ValueError("Scheduler should be either DDIM, EulerA, UniPC or LCM") + + self.current_scheduler = scheduler + self.denoising_steps = None + def set_denoising_steps(self, denoising_steps: int): - self.scheduler.set_timesteps(denoising_steps) - self.scheduler.configure() - self.denoising_steps = denoising_steps + if not (self.denoising_steps == denoising_steps and isinstance(self.scheduler, DDIMScheduler)): + self.scheduler.set_timesteps(denoising_steps) + self.scheduler.configure() + self.denoising_steps = denoising_steps def load_resources(self, image_height, image_width, batch_size): # If engine is built with static input shape, call this only once after engine build. @@ -173,8 +184,13 @@ def load_resources(self, image_height, image_width, batch_size): self.backend.load_resources(image_height, image_width, batch_size) def set_random_seed(self, seed): - # Initialize noise generator. Usually, it is done before a batch of inference. - self.generator = torch.Generator(device="cuda").manual_seed(seed) if isinstance(seed, int) else None + if isinstance(seed, int): + self.generator.manual_seed(seed) + else: + self.generator.seed() + + def get_current_seed(self): + return self.generator.initial_seed() def teardown(self): for e in self.events.values(): @@ -225,6 +241,7 @@ def encode_prompt( pooled_outputs=False, output_hidden_states=False, force_zeros_for_empty_prompt=False, + do_classifier_free_guidance=True, ): if tokenizer is None: tokenizer = self.tokenizer @@ -252,41 +269,44 @@ def encode_prompt( if output_hidden_states: hidden_states = outputs["hidden_states"].clone() - # Note: negative prompt embedding is not needed for SD XL when guidance < 1 - - # For SD XL base, handle force_zeros_for_empty_prompt - is_empty_negative_prompt = all([not i for i in negative_prompt]) - if force_zeros_for_empty_prompt and is_empty_negative_prompt: - uncond_embeddings = torch.zeros_like(text_embeddings) - if output_hidden_states: - uncond_hidden_states = torch.zeros_like(hidden_states) - else: - # Tokenize negative prompt - uncond_input_ids = ( - tokenizer( - negative_prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", + # Note: negative prompt embedding is not needed for SD XL when guidance <= 1 + if do_classifier_free_guidance: + # For SD XL base, handle force_zeros_for_empty_prompt + is_empty_negative_prompt = all([not i for i in negative_prompt]) + if force_zeros_for_empty_prompt and is_empty_negative_prompt: + uncond_embeddings = torch.zeros_like(text_embeddings) + if output_hidden_states: + uncond_hidden_states = torch.zeros_like(hidden_states) + else: + # Tokenize negative prompt + uncond_input_ids = ( + tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) ) - .input_ids.type(torch.int32) - .to(self.device) - ) - outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) - uncond_embeddings = outputs["text_embeddings"] - if output_hidden_states: - uncond_hidden_states = outputs["hidden_states"] + outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) + uncond_embeddings = outputs["text_embeddings"] + if output_hidden_states: + uncond_hidden_states = outputs["hidden_states"] - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) if pooled_outputs: pooled_output = text_embeddings if output_hidden_states: - text_embeddings = torch.cat([uncond_hidden_states, hidden_states]).to(dtype=torch.float16) + if do_classifier_free_guidance: + text_embeddings = torch.cat([uncond_hidden_states, hidden_states]).to(dtype=torch.float16) + else: + text_embeddings = hidden_states.to(dtype=torch.float16) cudart.cudaEventRecord(self.events["clip-stop"], 0) if self.nvtx_profile: @@ -308,7 +328,7 @@ def denoise_latent( guidance=7.5, add_kwargs=None, ): - assert guidance > 1.0, "Guidance has to be > 1.0" # TODO: remove this constraint + do_classifier_free_guidance = guidance > 1.0 cudart.cudaEventRecord(self.events["denoise-start"], 0) if not isinstance(timesteps, torch.Tensor): @@ -319,7 +339,7 @@ def denoise_latent( nvtx_latent_scale = nvtx.start_range(message="latent_scale", color="pink") # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input( latent_model_input, step_offset + step_index, timestep @@ -353,11 +373,14 @@ def denoise_latent( nvtx_latent_step = nvtx.start_range(message="latent_step", color="pink") # perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) if type(self.scheduler) == UniPCMultistepScheduler: latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + elif type(self.scheduler) == LCMScheduler: + latents = self.scheduler.step(noise_pred, timestep, latents, generator=self.generator)[0] else: latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep) @@ -393,38 +416,42 @@ def decode_latent(self, latents): nvtx.end_range(nvtx_vae) return images - def print_summary(self, tic, toc, batch_size, vae_enc=False): + def print_summary(self, tic, toc, batch_size, vae_enc=False) -> Dict[str, Any]: + throughput = batch_size / (toc - tic) + latency_clip = cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] + latency_unet = cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1] + latency_vae = cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] + latency_vae_encoder = ( + cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1] + if vae_enc + else None + ) + latency = (toc - tic) * 1000.0 + print("|------------|--------------|") print("| {:^10} | {:^12} |".format("Module", "Latency")) print("|------------|--------------|") if vae_enc: - print( - "| {:^10} | {:>9.2f} ms |".format( - "VAE-Enc", - cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1], - ) - ) - print( - "| {:^10} | {:>9.2f} ms |".format( - "CLIP", cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] - ) - ) - print( - "| {:^10} | {:>9.2f} ms |".format( - "UNet x " + str(self.actual_steps), - cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1], - ) - ) - print( - "| {:^10} | {:>9.2f} ms |".format( - "VAE-Dec", cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] - ) - ) + print("| {:^10} | {:>9.2f} ms |".format("VAE-Enc", latency_vae_encoder)) + print("| {:^10} | {:>9.2f} ms |".format("CLIP", latency_clip)) + print("| {:^10} | {:>9.2f} ms |".format("UNet x " + str(self.actual_steps), latency_unet)) + print("| {:^10} | {:>9.2f} ms |".format("VAE-Dec", latency_vae)) print("|------------|--------------|") - print("| {:^10} | {:>9.2f} ms |".format("Pipeline", (toc - tic) * 1000.0)) + print("| {:^10} | {:>9.2f} ms |".format("Pipeline", latency)) print("|------------|--------------|") - print(f"Throughput: {batch_size / (toc - tic):.2f} image/s") + print(f"Throughput: {throughput:.2f} image/s") + + perf_data = { + "latency_clip": latency_clip, + "latency_unet": latency_unet, + "latency_vae": latency_vae, + "latency": latency, + "throughput": throughput, + } + if vae_enc: + perf_data["latency_vae_encoder"] = latency_vae_encoder + return perf_data @staticmethod def to_pil_image(images): @@ -436,16 +463,31 @@ def to_pil_image(images): return [Image.fromarray(images[i]) for i in range(images.shape[0])] - def save_images(self, images, pipeline, prompt): - image_name_prefix = ( - pipeline + "".join(set(["-" + prompt[i].replace(" ", "_")[:10] for i in range(len(prompt))])) + "-" - ) + def metadata(self) -> Dict[str, Any]: + return { + "actual_steps": self.actual_steps, + "seed": self.get_current_seed(), + "name": self.pipeline_info.name(), + "custom_vae": self.pipeline_info.custom_fp16_vae(), + "custom_unet": self.pipeline_info.custom_unet(), + } + def save_images(self, images: List, prompt: List[str], negative_prompt: List[str], metadata: Dict[str, Any]): images = self.to_pil_image(images) - random_session_id = str(random.randint(1000, 9999)) + session_id = str(random.randint(1000, 9999)) for i, image in enumerate(images): - image_path = os.path.join( - self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + ".png" - ) + seed = str(self.get_current_seed()) + prefix = "".join(x for x in prompt[i] if x.isalnum() or x in ", -").replace(" ", "_")[:20] + parts = [prefix, session_id, str(i + 1), str(seed), self.current_scheduler, str(self.actual_steps)] + image_path = os.path.join(self.output_dir, "-".join(parts) + ".png") print(f"Saving image {i+1} / {len(images)} to: {image_path}") - image.save(image_path) + + from PIL import PngImagePlugin + + info = PngImagePlugin.PngInfo() + for k, v in metadata.items(): + info.add_text(k, str(v)) + info.add_text("prompt", prompt[i]) + info.add_text("negative_prompt", negative_prompt[i]) + + image.save(image_path, "PNG", pnginfo=info) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py index b9759b44e7635..87ce85af247a5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -84,11 +84,11 @@ def _infer( torch.cuda.synchronize() e2e_toc = time.perf_counter() + perf_data = None if not warmup: - self.print_summary(e2e_tic, e2e_toc, batch_size) - self.save_images(images, "txt2img", prompt) + perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - return images, (e2e_toc - e2e_tic) * 1000.0 + return images, perf_data def run( self, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py index 1b3be143e6ce7..8ed7e20e94c07 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py @@ -62,7 +62,7 @@ def _infer( return_type="image", ): assert len(prompt) == len(negative_prompt) - + do_classifier_free_guidance = guidance > 1.0 original_size = (image_height, image_width) crops_coords_top_left = (0, 0) target_size = (image_height, image_width) @@ -91,6 +91,7 @@ def _infer( tokenizer=self.tokenizer, output_hidden_states=True, force_zeros_for_empty_prompt=True, + do_classifier_free_guidance=do_classifier_free_guidance, ) # CLIP text encoder 2 text_embeddings2, pooled_embeddings2 = self.encode_prompt( @@ -101,6 +102,7 @@ def _infer( pooled_outputs=True, output_hidden_states=True, force_zeros_for_empty_prompt=True, + do_classifier_free_guidance=do_classifier_free_guidance, ) # Merged text embeddings @@ -111,9 +113,10 @@ def _infer( original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype ) add_time_ids = add_time_ids.repeat(batch_size, 1) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0).to(self.device) + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids.to(self.device)} # UNet denoiser latents = self.denoise_latent( @@ -133,13 +136,12 @@ def _infer( torch.cuda.synchronize() e2e_toc = time.perf_counter() + perf_data = None if not warmup: print("SD-XL Base Pipeline") - self.print_summary(e2e_tic, e2e_toc, batch_size) - if return_type != "latent": - self.save_images(images, "txt2img-xl", prompt) + perf_data = self.print_summary(e2e_tic, e2e_toc, batch_size) - return images, (e2e_toc - e2e_tic) * 1000.0 + return images, perf_data def run( self, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt index 5f908c4f5ff39..447cb54f98ed2 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda11.txt @@ -1,7 +1,7 @@ -r requirements.txt # Official onnxruntime-gpu 1.16.1 is built with CUDA 11.8. -onnxruntime-gpu>=1.16.1 +onnxruntime-gpu>=1.16.2 py3nvml @@ -12,7 +12,8 @@ cuda-python==11.8.0 # For windows, cuda-python need the following pywin32; platform_system == "Windows" -nvtx +# For windows, run `conda install -c conda-forge nvtx` instead +nvtx; platform_system != "Windows" # Please install PyTorch 2.1 or above for CUDA 11.8 using one of the following commands: # pip3 install torch --index-url https://download.pytorch.org/whl/cu118 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt index e4e765831c1b3..1ff0e3c1cf5af 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda12.txt @@ -1,18 +1,19 @@ -r requirements.txt # For CUDA 12.*, you will need build onnxruntime-gpu from source and install the wheel. See README.md for detail. -# onnxruntime-gpu>=1.16.1 +# onnxruntime-gpu>=1.16.2 py3nvml # The version of cuda-python shall be compatible with installed CUDA version. # For example, if your CUDA version is 12.1, you can install cuda-python 12.1. -cuda-python==12.1.0 +cuda-python>=12.1.0 # For windows, cuda-python need the following pywin32; platform_system == "Windows" -nvtx +# For windows, run `conda install -c conda-forge nvtx` instead +nvtx; platform_system != "Windows" # Please install PyTorch 2.1 or above for 12.1 using one of the following commands: # pip3 install torch --index-url https://download.pytorch.org/whl/cu121 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt index 9386a941fb323..63fa8acfbcc95 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -1,8 +1,8 @@ -diffusers>=0.19.3 -transformers>=4.31.0 +diffusers==0.23.1 +transformers==4.35.1 numpy>=1.24.1 accelerate -onnx>=1.13.0 +onnx==1.14.1 coloredlogs packaging # Use newer version of protobuf might cause crash @@ -10,6 +10,8 @@ protobuf==3.20.3 psutil sympy # The following are for SDXL -optimum>=1.11.1 +optimum==1.13.1 safetensors invisible_watermark +# newer version of opencv-python migth encounter module 'cv2.dnn' has no attribute 'DictValue' error +opencv-python==4.8.0.74 diff --git a/onnxruntime/python/tools/transformers/onnx_model_conformer.py b/onnxruntime/python/tools/transformers/onnx_model_conformer.py new file mode 100644 index 0000000000000..1506d85f53fd4 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_conformer.py @@ -0,0 +1,33 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +from typing import Optional + +from fusion_attention import AttentionMask +from fusion_conformer_attention import FusionConformerAttention +from fusion_options import FusionOptions +from onnx_model_bert import BertOnnxModel + +logger = logging.getLogger(__name__) + + +class ConformerOnnxModel(BertOnnxModel): + def __init__(self, model, num_heads, hidden_size): + super().__init__(model, num_heads, hidden_size) + self.attention_mask = AttentionMask(self) + self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention + self.attention_fusion.disable_multi_head_attention_bias = ( + False if options is None else options.disable_multi_head_attention_bias + ) + super().optimize(options, add_dynamic_axes) + + def fuse_attention(self): + self.attention_fusion.apply() + + def preprocess(self): + self.adjust_reshape_and_expand() diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 94a757320e598..6842a97fe0c77 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -32,6 +32,7 @@ from onnx_model_bert_keras import BertOnnxModelKeras from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_clip import ClipOnnxModel +from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel @@ -56,6 +57,7 @@ "unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), + "conformer": (ConformerOnnxModel, "pytorch", 1), } diff --git a/onnxruntime/test/contrib_ops/gemm_float8_test.cc b/onnxruntime/test/contrib_ops/gemm_float8_test.cc new file mode 100644 index 0000000000000..c022736075cde --- /dev/null +++ b/onnxruntime/test/contrib_ops/gemm_float8_test.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +#if defined(USE_CUDA) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + +TEST(GemmFloat8OpTest, BFloat16) { + OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddAttribute("activation", "NONE"); + test.AddAttribute("dtype", static_cast(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)); + test.AddInput("A", {2, 4}, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f})); + test.AddInput("B", {4, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})); + test.AddInput("C", {2, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f})); + test.AddOutput("Y", {2, 3}, MakeBFloat16({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f})); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(GemmFloat8OpTest, Float) { + OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddAttribute("activation", "NONE"); + test.AddAttribute("dtype", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + test.AddInput("A", {2, 4}, std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f})); + test.AddInput("B", {4, 3}, std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})); + test.AddInput("C", {2, 3}, std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f})); + test.AddOutput("Y", {2, 3}, std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f})); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +std::vector _Cvt(const std::vector& tensor) { + std::vector fp16_data(tensor.size()); + ConvertFloatToMLFloat16(tensor.data(), fp16_data.data(), static_cast(tensor.size())); + return fp16_data; +} + +TEST(GemmFloat8OpTest, Float16) { + OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddAttribute("activation", "NONE"); + test.AddAttribute("dtype", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + test.AddInput("A", {2, 4}, _Cvt(std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}))); + test.AddInput("B", {4, 3}, _Cvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + test.AddInput("C", {2, 3}, _Cvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + test.AddOutput("Y", {2, 3}, _Cvt(std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}))); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +#if (!defined(DISABLE_FLOAT8_TYPES)) && (CUDA_VERSION >= 12000) + +template +std::vector _TypedCvt(const std::vector& tensor); + +template <> +std::vector _TypedCvt(const std::vector& tensor) { + return tensor; +} + +template <> +std::vector _TypedCvt(const std::vector& tensor) { + std::vector out(tensor.size()); + for (size_t i = 0; i < tensor.size(); ++i) { + out[i] = Float8E4M3FN(tensor[i]); + } + return out; +} + +template +void TestGemmFloat8WithFloat8(int64_t dtype) { + int min_cuda_architecture = 11080; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support Matrix Multiplication for FLOAT8"; + return; + } + OpTester test("GemmFloat8", 1, onnxruntime::kMSDomain); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)1); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddAttribute("activation", "NONE"); + test.AddAttribute("dtype", dtype); + test.AddInput("A", {2, 4}, _TypeCvt(std::vector({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}))); + test.AddInput("B", {3, 4}, _TypeCvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + test.AddInput("C", {2, 3}, _TypeCvt(std::vector({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + test.AddOutput("Y", {2, 3}, _TypeCvt(std::vector({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}))); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(GemmFloat8OpTest, Float8E4M3FNToFloat) { + TestGemmFloat8WithFloat8(static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); +} + +TEST(GemmFloat8OpTest, Float8E4M3FNToFloat8E4M3FN) { + TestGemmFloat8WithFloat8(static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN)); +} + +#endif + +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 918ee0e6eb976..3c6217915bef0 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -72,21 +72,17 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop MlasTranspose(input1_f_vals.data(), input1_f_vals_trans.data(), K, N); #endif - int meta_rows; - int meta_cols; - MlasBlockwiseQuantMetaShape((int)block_size, true, (int)K, (int)N, meta_rows, meta_cols); + int q_rows, q_cols; + MlasBlockwiseQuantizedShape((int)block_size, true, (int)K, (int)N, q_rows, q_cols); - int q_rows; - int q_cols; - MlasBlockwiseQuantizedShape((int)block_size, true, (int)K, (int)N, q_rows, q_cols); + size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes(4, static_cast(block_size), /* columnwise */ true, + static_cast(K), static_cast(N), + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); - std::vector input1_vals(q_rows * q_cols); - std::vector scales(meta_rows * meta_cols); - - // TODO!! THIS SHOULD BE PROVIDED BY MLAS - // sub 8b packing always happen on the column dimension - const int packed_meta_rows = (meta_rows * QBits + 7) / 8; - std::vector zp(packed_meta_rows * meta_cols); + std::vector input1_vals(q_data_size_in_bytes); + std::vector scales(q_scale_size); + std::vector zp(q_zp_size_in_bytes); QuantizeDequantize(input1_f_vals, input1_vals, @@ -115,9 +111,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop if (use_float16) { test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); test.AddInput("B", {q_cols, q_rows}, input1_vals, true); - test.AddInput("scales", {meta_cols * meta_rows}, ToFloat16(scales), true); + test.AddInput("scales", {static_cast(q_scale_size)}, ToFloat16(scales), true); if (has_zeropoint) { - test.AddInput("zero_points", {meta_cols * packed_meta_rows}, zp, true); + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); } test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); @@ -129,9 +125,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop } else { test.AddInput("A", {M, K}, input0_vals, false); test.AddInput("B", {q_cols, q_rows}, input1_vals, true); - test.AddInput("scales", {meta_cols * meta_rows}, scales, true); + test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); if (has_zeropoint) { - test.AddInput("zero_points", {meta_cols * packed_meta_rows}, zp, true); + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); } test.AddOutput("Y", {M, N}, expected_vals); diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc new file mode 100644 index 0000000000000..ebb0261deefa5 --- /dev/null +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -0,0 +1,423 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void RunMoETest( + const std::vector& input, + const std::vector& router_probs, + const std::vector& fc1_experts_weights, + const std::vector& fc2_experts_weights, + const std::vector& fc1_experts_bias, + const std::vector& fc2_experts_bias, + const std::vector& output_data, + int num_rows, + int num_experts, + int hidden_size, + int inter_size, + std::string activation_type, + bool use_float16 = false) { + int min_cuda_architecture = use_float16 ? 530 : 0; + + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + if (enable_cuda) { + OpTester tester("MoE", 1, onnxruntime::kMSDomain); + tester.AddAttribute("k", static_cast(1)); + tester.AddAttribute("activation_type", activation_type); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; + std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc1_experts_bias_dims = {num_experts, inter_size}; + std::vector fc2_experts_bias_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + if (use_float16) { + tester.AddInput("input", input_dims, ToFloat16(input)); + tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, ToFloat16(fc1_experts_weights)); + tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, ToFloat16(fc2_experts_weights)); + tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, ToFloat16(fc1_experts_bias)); + tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, ToFloat16(fc2_experts_bias)); + tester.AddOutput("output", output_dims, ToFloat16(output_data)); + } else { + tester.AddInput("input", input_dims, input); + tester.AddInput("router_probs", router_probs_dims, router_probs); + tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias); + tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias); + tester.AddOutput("output", output_dims, output_data); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MoETest, MoETest_Gelu) { + int num_rows = 4; + int num_experts = 4; + int hidden_size = 8; + int inter_size = 16; + + const std::vector input = { + -1.1200173f, -0.45884353f, -1.2929888f, 1.0784022f, 0.116372705f, 0.26902613f, -1.8818876f, -0.5457026f, + 0.22222236f, -0.28868636f, 0.6692926f, 1.4944887f, 0.02431708f, -0.49781424f, 0.7378293f, 1.276276f, + -0.15469065f, -0.28456813f, -0.6296439f, -0.24855971f, 0.80565417f, -1.1018785f, -0.74082595f, 0.82407707f, + -0.95033455f, 0.659333f, -0.68629056f, -0.2916592f, 1.869919f, -1.1053563f, -0.14417848f, -0.34625578f}; + const std::vector router_probs = { + -0.84837115f, 0.100507565f, -0.10548311f, 0.40957215f, 1.0159845f, 0.26919764f, 0.021741152f, -0.34184334f, + -0.71324956f, 0.29018253f, -0.18227568f, 0.31496462f, -0.48426327f, -1.006643f, -0.100081146f, -0.07692295f}; + const std::vector fc1_experts_weights = { + 0.14731085f, 0.52229995f, 0.14753294f, 0.22475791f, 0.20864725f, 0.6708725f, 0.20204341f, 0.4890914f, + 0.52103406f, 0.8223115f, 0.122039974f, 0.15674388f, 0.20966923f, 0.8499667f, 0.3202675f, 0.92174435f, + 0.6808038f, 0.563313f, 0.496278f, 0.40115923f, 0.5627332f, 0.38582766f, 0.49648678f, 0.5637965f, + 0.10889745f, 0.23793429f, 0.90374637f, 0.09422666f, 0.4640969f, 0.99461937f, 0.6806185f, 0.5141565f, + 0.066695035f, 0.74768895f, 0.14385962f, 0.35806787f, 0.33224183f, 0.4259563f, 0.50546914f, 0.91240376f, + 0.5624194f, 0.9478464f, 0.8058562f, 0.18389302f, 0.72425205f, 0.14655197f, 0.28808743f, 0.64706135f, + 0.66509604f, 0.875114f, 0.33904207f, 0.50080043f, 0.7574118f, 0.016453922f, 0.8614903f, 0.08653879f, + 0.50689125f, 0.41499162f, 0.23666352f, 0.5660855f, 0.91345936f, 0.35384023f, 0.20315295f, 0.31508058f, + 0.0044258237f, 0.725697f, 0.25986814f, 0.16632986f, 0.21194929f, 0.787478f, 0.76478684f, 0.8837609f, + 0.68136156f, 0.33302015f, 0.36027592f, 0.647715f, 0.91101736f, 0.6359461f, 0.26342732f, 0.2649613f, + 0.02726549f, 0.608024f, 0.21940875f, 0.054212093f, 0.93843824f, 0.1752944f, 0.44311923f, 0.64324677f, + 0.51592916f, 0.16355914f, 0.09583914f, 0.8985412f, 0.58141935f, 0.91481227f, 0.3323797f, 0.6472777f, + 0.3856619f, 0.47776443f, 0.1954779f, 0.66910046f, 0.65808296f, 0.4896857f, 0.38754892f, 0.1917851f, + 0.8457724f, 0.12778795f, 0.70483273f, 0.33187324f, 0.258766f, 0.58982253f, 0.24027151f, 0.6152024f, + 0.5981904f, 0.12875527f, 0.5832493f, 0.7129646f, 0.6979155f, 0.43706065f, 0.09010619f, 0.42292297f, + 0.67365384f, 0.31756145f, 0.68979055f, 0.8329813f, 0.2389242f, 0.5049309f, 0.7067495f, 0.5391889f, + 0.54176575f, 0.5624327f, 0.10692614f, 0.5392941f, 0.8462349f, 0.9505569f, 0.79387546f, 0.5670015f, + 0.7335071f, 0.25676018f, 0.08565581f, 0.07003945f, 0.99880487f, 0.8173947f, 0.15438312f, 0.6956213f, + 0.8775838f, 0.9998074f, 0.93719745f, 0.8873769f, 0.38537037f, 0.32452917f, 0.9105244f, 0.7801898f, + 0.19911051f, 0.9495086f, 0.7415793f, 0.77256775f, 0.18661183f, 0.6434499f, 0.32471877f, 0.8906783f, + 0.4100297f, 0.69465625f, 0.5888109f, 0.7127341f, 0.33008623f, 0.7437857f, 0.15076452f, 0.6129275f, + 0.16170406f, 0.006731212f, 0.09847212f, 0.89473504f, 0.7705178f, 0.96910787f, 0.9005606f, 0.053477287f, + 0.15878445f, 0.4192087f, 0.17528385f, 0.84719825f, 0.121996105f, 0.25604928f, 0.016954303f, 0.21612722f, + 0.91123873f, 0.90938f, 0.85791886f, 0.88606364f, 0.94459325f, 0.3719685f, 0.72000104f, 0.9454652f, + 0.6654094f, 0.9998382f, 0.75933146f, 0.81082416f, 0.32500392f, 0.73991376f, 0.5574533f, 0.38059133f, + 0.21814507f, 0.21944171f, 0.11525959f, 0.83566517f, 0.8554656f, 0.44309366f, 0.210657f, 0.88645273f, + 0.81974447f, 0.537167f, 0.26393235f, 0.9595239f, 0.70447034f, 0.12042731f, 0.97854143f, 0.8796869f, + 0.31775457f, 0.78107727f, 0.21590549f, 0.42164284f, 0.9245506f, 0.52065957f, 0.14639091f, 0.33288354f, + 0.36427742f, 0.4035356f, 0.5478503f, 0.9624148f, 0.5267702f, 0.19128f, 0.52562714f, 0.7397436f, + 0.7480201f, 0.04303074f, 0.41052878f, 0.12842774f, 0.2866572f, 0.6801467f, 0.1449349f, 0.68586344f, + 0.92438906f, 0.5327942f, 0.16675615f, 0.32085752f, 0.60918206f, 0.11884099f, 0.74840516f, 0.04606521f, + 0.01935333f, 0.014169693f, 0.39856833f, 0.83621645f, 0.026760519f, 0.91559356f, 0.29998857f, 0.64644206f, + 0.52280146f, 0.049140453f, 0.9146645f, 0.7692217f, 0.99699783f, 0.7526061f, 0.1699655f, 0.9172919f, + 0.5268722f, 0.73710823f, 0.09908545f, 0.35618675f, 0.009061217f, 0.30525374f, 0.6078656f, 0.10741913f, + 0.6593821f, 0.7684034f, 0.56965464f, 0.16545832f, 0.11234015f, 0.3457417f, 0.7194791f, 0.9931982f, + 0.7875145f, 0.44369537f, 0.6753082f, 0.009468555f, 0.07294935f, 0.73330396f, 0.2167924f, 0.74054784f, + 0.14703393f, 0.25234455f, 0.08815551f, 0.76092035f, 0.44905245f, 0.88480055f, 0.8094361f, 0.7766713f, + 0.51607805f, 0.345411f, 0.39128417f, 0.5664503f, 0.74785477f, 0.14970505f, 0.91963893f, 0.44563496f, + 0.08102721f, 0.22947109f, 0.94240886f, 0.9572636f, 0.036860168f, 0.85264915f, 0.7505796f, 0.79595923f, + 0.9232646f, 0.23052484f, 0.6578879f, 0.7046166f, 0.35225332f, 0.66732657f, 0.3561433f, 0.80913067f, + 0.3612727f, 0.31360215f, 0.6258745f, 0.6773468f, 0.25571418f, 0.54419917f, 0.78976786f, 0.45025164f, + 0.65216696f, 0.3794065f, 0.6752498f, 0.1378029f, 0.2059856f, 0.24620473f, 0.95950544f, 0.36545795f, + 0.49863482f, 0.25775224f, 0.99914503f, 0.9883351f, 0.122906685f, 0.09466505f, 0.12100351f, 0.49758863f, + 0.37254804f, 0.17272717f, 0.32066393f, 0.59446543f, 0.23875463f, 0.61079127f, 0.38534206f, 0.25771832f, + 0.56869274f, 0.9111291f, 0.16196036f, 0.5232172f, 0.31561613f, 0.99065316f, 0.025618374f, 0.0206694f, + 0.9926925f, 0.18365502f, 0.5958617f, 0.45684695f, 0.3946715f, 0.3883261f, 0.8177203f, 0.5238985f, + 0.013192713f, 0.20481992f, 0.32954985f, 0.7516082f, 0.17643315f, 0.9714598f, 0.38863534f, 0.410219f, + 0.891779f, 0.75130385f, 0.92406017f, 0.7892222f, 0.34832305f, 0.1682638f, 0.46279848f, 0.9138188f, + 0.3321901f, 0.036315024f, 0.7049642f, 0.9867357f, 0.3576584f, 0.08598822f, 0.046470165f, 0.6252997f, + 0.46214014f, 0.24750638f, 0.60106593f, 0.6898794f, 0.8976595f, 0.8881911f, 0.42515814f, 0.059116423f, + 0.048188448f, 0.9668448f, 0.7210276f, 0.7179537f, 0.06738949f, 0.96300787f, 0.97367156f, 0.95143014f, + 0.07820749f, 0.3113383f, 0.1561181f, 0.9734828f, 0.28516f, 0.27172273f, 0.76195645f, 0.26870382f, + 0.25373894f, 0.45626426f, 0.45194024f, 0.11051077f, 0.91683406f, 0.27943915f, 0.67735744f, 0.9348918f, + 0.7521582f, 0.57078993f, 0.9254285f, 0.5672131f, 0.2686717f, 0.97299975f, 0.61834025f, 0.012159586f, + 0.3576542f, 0.15941626f, 0.9383765f, 0.41742706f, 0.044237554f, 0.46856833f, 0.81400645f, 0.6299002f, + 0.6581022f, 0.5464366f, 0.68640935f, 0.378174f, 0.3010999f, 0.032645762f, 0.12333155f, 0.71670127f, + 0.20394331f, 0.57173324f, 0.6595957f, 0.53540194f, 0.17582512f, 0.9781642f, 0.20925027f, 0.9112503f, + 0.10224587f, 0.37972575f, 0.7719844f, 0.29570967f, 0.9200215f, 0.15592176f, 0.080114245f, 0.27454042f, + 0.5808252f, 0.96037793f, 0.26129955f, 0.6788141f, 0.37464648f, 0.39156884f, 0.8676517f, 0.112507045f, + 0.55310667f, 0.9702046f, 0.4312939f, 0.88821906f, 0.3460216f, 0.9024811f, 0.016334832f, 0.42793816f, + 0.4121768f, 0.6620425f, 0.6961637f, 0.88390845f, 0.425507f, 0.48017246f, 0.8424056f, 0.36471343f, + 0.9383168f, 0.16709393f, 0.44589508f, 0.47314453f, 0.72310495f, 0.84183806f, 0.4207481f, 0.0857597f, + 0.7477461f, 0.6495659f, 0.70084965f, 0.19156617f, 0.8217978f, 0.9735775f, 0.5433857f, 0.032975793f, + 0.85099494f, 0.12927437f, 0.61493605f, 0.5726589f, 0.26598173f, 0.6740978f, 0.052783668f, 0.61387974f}; + const std::vector fc2_experts_weights = { + 0.18302453f, 0.44593316f, 0.5643144f, 0.9259722f, 0.26143986f, 0.82031804f, 0.4364831f, 0.2625361f, + 0.06460017f, 0.04124081f, 0.98830533f, 0.37530023f, 0.5249744f, 0.63555616f, 0.8398661f, 0.92673707f, + 0.9055086f, 0.12955844f, 0.4198916f, 0.20413119f, 0.21432412f, 0.6186035f, 0.969324f, 0.099448025f, + 0.80260223f, 0.24076664f, 0.40261286f, 0.89688545f, 0.38691485f, 0.5455279f, 0.15048373f, 0.92562044f, + 0.43536508f, 0.13430476f, 0.64640516f, 0.14449131f, 0.10324633f, 0.5304596f, 0.8964218f, 0.358508f, + 0.73533344f, 0.9296606f, 0.83163047f, 0.23771948f, 0.44519007f, 0.34265757f, 0.09793854f, 0.5002066f, + 0.87621754f, 0.9212578f, 0.54665035f, 0.6135615f, 0.28353918f, 0.8774212f, 0.29194576f, 0.1526736f, + 0.57699674f, 0.7996927f, 0.04920423f, 0.95198375f, 0.67986554f, 0.14969361f, 0.39229625f, 0.93378997f, + 0.11638266f, 0.3538614f, 0.66399014f, 0.06195748f, 0.7740991f, 0.7602738f, 0.81010276f, 0.18122643f, + 0.9980005f, 0.20361924f, 0.99917024f, 0.020154774f, 0.054515004f, 0.80709815f, 0.55225646f, 0.52884465f, + 0.22312081f, 0.29026228f, 0.35380626f, 0.012922287f, 0.52598435f, 0.58842945f, 0.4995767f, 0.66146517f, + 0.9744255f, 0.632942f, 0.3169638f, 0.29422665f, 0.18009722f, 0.15339059f, 0.41947508f, 0.4115672f, + 0.72243124f, 0.2862816f, 0.89860183f, 0.14915991f, 0.5014211f, 0.94945997f, 0.99719256f, 0.21036887f, + 0.5890645f, 0.55906135f, 0.26557416f, 0.32725257f, 0.635427f, 0.1523174f, 0.58249784f, 0.71636236f, + 0.30296493f, 0.9153206f, 0.46709478f, 0.72685635f, 0.9951532f, 0.34716582f, 0.7717041f, 0.3569854f, + 0.4269635f, 0.41526443f, 0.4968937f, 0.3111158f, 0.61719346f, 0.5188402f, 0.8169449f, 0.39879733f, + 0.5501401f, 0.31400484f, 0.08127314f, 0.7023336f, 0.56397897f, 0.29975814f, 0.33094752f, 0.63076067f, + 0.40959156f, 0.82673794f, 0.52832156f, 0.68886834f, 0.7178481f, 0.37731683f, 0.71633244f, 0.86896664f, + 0.5230092f, 0.59784645f, 0.5181678f, 0.8461837f, 0.28890234f, 0.23421508f, 0.7178768f, 0.06484294f, + 0.5080162f, 0.27005446f, 0.8300168f, 0.034480453f, 0.8031663f, 0.9946784f, 0.60117006f, 0.46668667f, + 0.9921749f, 0.28632385f, 0.45993322f, 0.28104752f, 0.43097937f, 0.60866946f, 0.5667807f, 0.40556252f, + 7.969141e-05f, 0.52560204f, 0.48518902f, 0.5752184f, 0.8831251f, 0.9860047f, 0.20335877f, 0.46882278f, + 0.2996632f, 0.03917718f, 0.13617045f, 0.96928054f, 0.79153055f, 0.76857555f, 0.7778716f, 0.102760494f, + 0.5525096f, 0.9653573f, 0.22095704f, 0.94479716f, 0.63141924f, 0.8517718f, 0.28580618f, 0.73050886f, + 0.05675614f, 0.46825224f, 0.6667756f, 0.6499472f, 0.91840404f, 0.99132854f, 0.9548785f, 0.8356961f, + 0.851531f, 0.43548512f, 0.111976564f, 0.31438643f, 0.44386774f, 0.22980672f, 0.75558543f, 0.6755136f, + 0.58067596f, 0.62078035f, 0.93922615f, 0.6821157f, 0.061530292f, 0.13705963f, 0.7203748f, 0.5681396f, + 0.7438458f, 0.0006400347f, 0.038565338f, 0.8066132f, 0.81982285f, 0.047644496f, 0.68979263f, 0.109577894f, + 0.8786539f, 0.6568952f, 0.99439347f, 0.0070040226f, 0.018661916f, 0.838051f, 0.94391155f, 0.80634f, + 0.8324149f, 0.078864336f, 0.8619068f, 0.027926445f, 0.61170083f, 0.17248261f, 0.30140227f, 0.5885344f, + 0.30341f, 0.42088854f, 0.02608782f, 0.02856338f, 0.69368154f, 0.28836077f, 0.19580519f, 0.30270886f, + 0.09121573f, 0.100299895f, 0.79918617f, 0.75412107f, 0.56660175f, 0.22687018f, 0.6663505f, 0.5224626f, + 0.1426636f, 0.6075949f, 0.95527196f, 0.008196831f, 0.0028039217f, 0.5640625f, 0.87651116f, 0.19575512f, + 0.61006856f, 0.85149264f, 0.6541582f, 0.6082054f, 0.998863f, 0.82573634f, 0.21878648f, 0.54321826f, + 0.7554362f, 0.94095474f, 0.002533555f, 0.77075267f, 0.35483408f, 0.010389388f, 0.610987f, 0.22779316f, + 0.5708561f, 0.17537653f, 0.12373549f, 0.4575745f, 0.33203715f, 0.79243237f, 0.54310906f, 0.8902793f, + 0.5937015f, 0.33921933f, 0.8386668f, 0.52732253f, 0.59384584f, 0.3391887f, 0.5017944f, 0.40386343f, + 0.45749134f, 0.110060334f, 0.49692506f, 0.084977865f, 0.3924346f, 0.7897731f, 0.15232486f, 0.16297412f, + 0.37791175f, 0.36293298f, 0.5846437f, 0.5830078f, 0.75354826f, 0.15555972f, 0.4647144f, 0.7796456f, + 0.93248576f, 0.46352726f, 0.2106899f, 0.6437313f, 0.78473866f, 0.18762505f, 0.20985329f, 0.7209991f, + 0.464967f, 0.02775067f, 0.21170747f, 0.7027664f, 0.33041215f, 0.8451145f, 0.89526993f, 0.57273495f, + 0.46046263f, 0.34128642f, 0.47471708f, 0.59101045f, 0.11807448f, 0.38050216f, 0.08409953f, 0.80687743f, + 0.18158185f, 0.9567719f, 0.3711096f, 0.21356237f, 0.74022657f, 0.57453954f, 0.846228f, 0.70873487f, + 0.018330276f, 0.8162452f, 0.40584308f, 0.27901447f, 0.81752694f, 0.86466515f, 0.060534656f, 0.45478833f, + 0.9106033f, 0.6936434f, 0.92123467f, 0.32865065f, 0.22417879f, 0.9299548f, 0.70841146f, 0.97999126f, + 0.2911517f, 0.17896658f, 0.44139355f, 0.029210031f, 0.6959876f, 0.8687942f, 0.62002844f, 0.45059657f, + 0.74790317f, 0.18262434f, 0.98912156f, 0.0028281808f, 0.021027386f, 0.38184917f, 0.90842223f, 0.5500629f, + 0.69202286f, 0.13349658f, 0.6823429f, 0.44412827f, 0.7004118f, 0.8531213f, 0.7173401f, 0.4574679f, + 0.46920043f, 0.18640989f, 0.31914896f, 0.82491904f, 0.29950172f, 0.8105199f, 0.30173403f, 0.38355058f, + 0.5106411f, 0.04116726f, 0.49500751f, 0.44960213f, 0.45508182f, 0.4000479f, 0.89418864f, 0.8689936f, + 0.16112137f, 0.7322634f, 0.10780871f, 0.07433933f, 0.652841f, 0.50734824f, 0.26674682f, 0.017748117f, + 0.30643195f, 0.66699976f, 0.03719926f, 0.014267266f, 0.56343627f, 0.13979793f, 0.061959863f, 0.3073569f, + 0.41949958f, 0.045647383f, 0.16613615f, 0.5327839f, 0.028514147f, 0.4297228f, 0.17714864f, 0.15338135f, + 0.6965155f, 0.11515516f, 0.1210829f, 0.78514075f, 0.59348315f, 0.9553564f, 0.36635226f, 0.25849247f, + 0.45372677f, 0.5025297f, 0.88132215f, 0.0019600391f, 0.46439964f, 0.7211761f, 0.22465849f, 0.2459296f, + 0.7416339f, 0.020907402f, 0.6184779f, 0.112906754f, 0.7485309f, 0.072479784f, 0.8074024f, 0.026683688f, + 0.07971662f, 0.50736845f, 0.8939942f, 0.0718022f, 0.27697015f, 0.9391413f, 0.4161513f, 0.7071423f, + 0.019000888f, 0.34275955f, 0.24608392f, 0.9215306f, 0.70751995f, 0.13516217f, 0.5806135f, 0.49425328f, + 0.29456508f, 0.21446168f, 0.3340807f, 0.89411324f, 0.14157385f, 0.14382833f, 0.34574044f, 0.50869817f, + 0.63610595f, 0.51500404f, 0.37963718f, 0.19682491f, 0.41028368f, 0.29872334f, 0.9039644f, 0.013295233f, + 0.1810705f, 0.093204916f, 0.4086216f, 0.8896367f, 0.9382696f, 0.06472236f, 0.47833657f, 0.7934831f, + 0.7203987f, 0.9095519f, 0.4861309f, 0.16405362f, 0.83076525f, 0.3285427f, 0.7588931f, 0.37678176f, + 0.71254706f, 0.949713f, 0.96492773f, 0.044967473f, 0.16925985f, 0.2932666f, 0.18114948f, 0.97975004f, + 0.4558406f, 0.16832972f, 0.27750528f, 0.2238177f, 0.7039947f, 0.06387442f, 0.033798456f, 0.007119417f}; + const std::vector fc1_experts_bias = { + 0.71526206f, 0.7472273f, 0.18946046f, 0.6239893f, 0.86909235f, 0.5726507f, 0.3942092f, 0.5369412f, + 0.44638616f, 0.7517496f, 0.16049433f, 0.75355124f, 0.7818118f, 0.19706267f, 0.9082818f, 0.9910924f, + 0.30288565f, 0.3599528f, 0.74917775f, 0.10828978f, 0.697729f, 0.61665237f, 0.81516486f, 0.0656966f, + 0.0846076f, 0.72456455f, 0.6801054f, 0.034616888f, 0.22117025f, 0.042510748f, 0.14178854f, 0.27440017f, + 0.91376925f, 0.40047455f, 0.7871756f, 0.97484046f, 0.7278661f, 0.052394807f, 0.75161135f, 0.6907173f, + 0.8875328f, 0.0067828894f, 0.807508f, 0.9092707f, 0.034817636f, 0.55231315f, 0.92683655f, 0.13634592f, + 0.66405964f, 0.7209387f, 0.63104504f, 0.9971379f, 0.9093898f, 0.9289774f, 0.4376766f, 0.9193563f, + 0.03404367f, 0.23018533f, 0.39305943f, 0.3514716f, 0.96184736f, 0.73583263f, 0.8219065f, 0.8401047f}; + const std::vector fc2_experts_bias = { + 0.12649822f, 0.4420895f, 0.5730123f, 0.63004625f, 0.7571163f, 0.3010466f, 0.3492328f, 0.91837066f, + 0.36580783f, 0.15267932f, 0.8390199f, 0.83857775f, 0.34321654f, 0.40003997f, 0.13106f, 0.08245313f, + 0.68802476f, 0.28640372f, 0.89804775f, 0.09964341f, 0.43088746f, 0.5107959f, 0.75697356f, 0.90466535f, + 0.83860224f, 0.720098f, 0.2705031f, 0.14292616f, 0.052693605f, 0.5248023f, 0.9849401f, 0.40502876f}; + const std::vector output = { + 0.2552814f, 0.17651685f, 0.0034551744f, -0.123282805f, 0.0073816925f, 0.004265253f, 0.16927283f, -0.05276826f, + 9.555821f, 7.6907287f, 10.626425f, 7.0543795f, 8.10093f, 10.3664465f, 10.925815f, 8.737018f, + 0.565234f, 0.17098689f, 0.10810414f, 0.43916586f, 0.3535297f, 0.45673048f, 0.3853893f, 0.18613164f, + 1.3354061f, 0.5049282f, 0.72775036f, 0.90331376f, 1.2945517f, 0.9123066f, 1.1995136f, 0.7708638f}; + + RunMoETest(input, + router_probs, + fc1_experts_weights, + fc2_experts_weights, + fc1_experts_bias, + fc2_experts_bias, + output, + num_rows, + num_experts, + hidden_size, + inter_size, + "gelu"); +} + +TEST(MoETest, MoETest_Relu) { + int num_rows = 4; + int num_experts = 4; + int hidden_size = 8; + int inter_size = 16; + + const std::vector input = { + 0.7670296f, -0.93721074f, -2.330477f, -0.78088343f, 0.8250065f, 1.2206652f, -0.06297584f, 1.1463639f, + 1.2215378f, -0.31372663f, -0.7234253f, -0.3627346f, 0.44249064f, 0.19418247f, -0.49998695f, -0.55005103f, + 0.023851749f, -1.5203826f, 0.52939993f, -0.39082858f, -1.9291036f, 0.034976702f, -0.48336256f, -1.226073f, + -0.33963847f, 0.0073261578f, -0.0521804f, 1.16749f, 1.7302082f, 2.0561688f, -0.2347232f, -1.3456243f}; + const std::vector router_probs = { + -0.08146476f, -0.40439552f, 1.0100367f, -0.7724162f, -0.08113786f, -0.36328858f, 0.3688482f, -0.013465762f, + -0.32420647f, -0.3815508f, 0.79585606f, 0.14430691f, -0.21869831f, 0.11483674f, -0.11992836f, 0.35216537f}; + const std::vector fc1_experts_weights = { + 0.81960344f, 0.9296998f, 0.45050132f, 0.38805157f, 0.50729614f, 0.47014588f, 0.62020564f, 0.6401168f, + 0.045871615f, 0.31548113f, 0.92106473f, 0.6947775f, 0.4751312f, 0.19854712f, 0.19409746f, 0.052116573f, + 0.3370188f, 0.6688521f, 0.8188108f, 0.73084867f, 0.058027983f, 0.19931877f, 0.42109168f, 0.98367476f, + 0.57232875f, 0.37051463f, 0.7068576f, 0.30955923f, 0.17637217f, 0.8649436f, 0.2726491f, 0.39976662f, + 0.0025978684f, 0.8346353f, 0.8788173f, 0.6822241f, 0.1513629f, 0.0065300465f, 0.093910515f, 0.8728501f, + 0.7400529f, 0.9207522f, 0.76193494f, 0.6265461f, 0.49510366f, 0.11974698f, 0.07161391f, 0.032325685f, + 0.704681f, 0.254516f, 0.3993737f, 0.21224737f, 0.40888822f, 0.14808255f, 0.17329216f, 0.6658554f, + 0.3514018f, 0.8086716f, 0.33959562f, 0.13321638f, 0.41178054f, 0.2576263f, 0.3470292f, 0.024002194f, + 0.77974546f, 0.15189773f, 0.75130886f, 0.7268921f, 0.85721636f, 0.11647397f, 0.8595984f, 0.2636242f, + 0.6855346f, 0.96955734f, 0.42948407f, 0.49613327f, 0.38488472f, 0.08250773f, 0.73995143f, 0.003641069f, + 0.81039995f, 0.87411255f, 0.9728532f, 0.38206023f, 0.08917904f, 0.61241513f, 0.77621365f, 0.0023456216f, + 0.38650817f, 0.20027226f, 0.45626813f, 0.25389326f, 0.2956162f, 0.34127057f, 0.024847984f, 0.91025376f, + 0.9191656f, 0.42156547f, 0.44305897f, 0.29594004f, 0.04846859f, 0.013427794f, 0.6858292f, 0.22547692f, + 0.17856151f, 0.4609884f, 0.33349442f, 0.3382396f, 0.5160656f, 0.3939438f, 0.3278438f, 0.26059705f, + 0.0930863f, 0.9192536f, 0.29990643f, 0.63248974f, 0.32651705f, 0.54063064f, 0.9661502f, 0.73036134f, + 0.06670016f, 0.6984514f, 0.9746214f, 0.63154167f, 0.83521235f, 0.99294376f, 0.4233855f, 0.6037772f, + 0.15248245f, 0.39696145f, 0.8702919f, 0.7563229f, 0.18360549f, 0.099057496f, 0.15831816f, 0.00656116f, + 0.114180505f, 0.3763513f, 0.8374386f, 0.5836911f, 0.11969727f, 0.09888804f, 0.74873763f, 0.12807935f, + 0.43843627f, 0.739853f, 0.26859397f, 0.44548005f, 0.45647776f, 0.38170832f, 0.24648392f, 0.054280818f, + 0.0958215f, 0.23226917f, 0.98291886f, 0.25849265f, 0.16423601f, 0.6211971f, 0.63780516f, 0.77395487f, + 0.8800602f, 0.7784371f, 0.004249513f, 0.5443443f, 0.80287653f, 0.45378727f, 0.20536041f, 0.9766699f, + 0.31298608f, 0.21532774f, 0.04922247f, 0.52233416f, 0.72156656f, 0.6106814f, 0.59887487f, 0.12080628f, + 0.03305638f, 0.5088047f, 0.95591706f, 0.7884607f, 0.20888287f, 0.43509573f, 0.13140821f, 0.2587883f, + 0.5905492f, 0.77226925f, 0.91418463f, 0.04094696f, 0.8343076f, 0.14735395f, 0.6872336f, 0.92312264f, + 0.5070212f, 0.9549045f, 0.07397425f, 0.3090204f, 0.79162645f, 0.39106607f, 0.39764988f, 0.29160416f, + 0.84465307f, 0.7452516f, 0.66022503f, 0.21901816f, 0.09412521f, 0.5540803f, 0.6481394f, 0.26914406f, + 0.36010116f, 0.83768386f, 0.53982985f, 0.52255917f, 0.37694973f, 0.04720515f, 0.029871285f, 0.26099247f, + 0.2458393f, 0.6557768f, 0.35444462f, 0.30438894f, 0.9767149f, 0.67416143f, 0.85645115f, 0.25794363f, + 0.2957666f, 0.68377024f, 0.16686243f, 0.17314798f, 0.47585016f, 0.31711966f, 0.125171f, 0.7965795f, + 0.90208143f, 0.58111167f, 0.41294336f, 0.036863506f, 0.31788063f, 0.6272928f, 0.73576546f, 0.43679124f, + 0.30232358f, 0.77861303f, 0.10180014f, 0.816009f, 0.30602258f, 0.5076527f, 0.40119207f, 0.5606195f, + 0.3489008f, 0.8635635f, 0.48700142f, 0.89029974f, 0.98074025f, 0.25640452f, 0.13524544f, 0.901151f, + 0.89180696f, 0.11822635f, 0.46134835f, 0.006936848f, 0.09070045f, 0.59657127f, 0.6330173f, 0.6059905f, + 0.36391765f, 0.96128887f, 0.571489f, 0.2049576f, 0.4716931f, 0.6200726f, 0.67509633f, 0.14645958f, + 0.6873948f, 0.24455917f, 0.08452982f, 0.22689629f, 0.9822047f, 0.9274289f, 0.9477422f, 0.7935056f, + 0.87772477f, 0.43307513f, 0.22488606f, 0.7498283f, 0.24090862f, 0.16256708f, 0.34033298f, 0.6049296f, + 0.7573983f, 0.3057955f, 0.20571685f, 0.56744653f, 0.2052834f, 0.17446929f, 0.76062596f, 0.4160077f, + 0.9568925f, 0.9863913f, 0.64955276f, 0.67207885f, 0.61514187f, 0.50783044f, 0.46363378f, 0.50687206f, + 0.6867124f, 0.9648854f, 0.37042046f, 0.2886421f, 0.37891757f, 0.25843787f, 0.58501935f, 0.8732242f, + 0.8909887f, 0.72956276f, 0.13203424f, 0.23164761f, 0.3901443f, 0.40783793f, 0.54112387f, 0.041014254f, + 0.65562236f, 0.11856395f, 0.18362767f, 0.08430874f, 0.9356598f, 0.026530087f, 0.8771834f, 0.48319155f, + 0.4418506f, 0.81273925f, 0.4537862f, 0.81357706f, 0.8615075f, 0.06589496f, 0.692392f, 0.5943895f, + 0.60750586f, 0.5729957f, 0.6367655f, 0.2594666f, 0.43602943f, 0.97506f, 0.83592474f, 0.48121578f, + 0.029734552f, 0.5219139f, 0.15951324f, 0.90659577f, 0.19645631f, 0.4638992f, 0.38902867f, 0.5889769f, + 0.9705138f, 0.5475096f, 0.789582f, 0.8881108f, 0.9036556f, 0.32732427f, 0.38817167f, 0.7409689f, + 0.36356616f, 0.734132f, 0.39076614f, 0.16087383f, 0.70352167f, 0.576659f, 0.7229242f, 0.996743f, + 0.84136647f, 0.97399056f, 0.5267614f, 0.06989372f, 0.14923638f, 0.18941313f, 0.059375823f, 0.24937624f, + 0.039716125f, 0.038692355f, 0.20122272f, 0.0070830584f, 0.19309378f, 0.69065434f, 0.9170264f, 0.3512686f, + 0.3545606f, 0.76697665f, 0.25331455f, 0.26358372f, 0.80806476f, 0.064349174f, 0.5611374f, 0.941691f, + 0.58574325f, 0.6359719f, 0.20880443f, 0.49310172f, 0.5274922f, 0.62271714f, 0.694273f, 0.9344639f, + 0.11835027f, 0.51498765f, 0.25018185f, 0.10446805f, 0.45996118f, 0.059881568f, 0.8489496f, 0.5579074f, + 0.23052096f, 0.76128954f, 0.02678603f, 0.3066004f, 0.40259063f, 0.07512486f, 0.18205583f, 0.4183907f, + 0.8793823f, 0.9828271f, 0.8181312f, 0.20143801f, 0.17288941f, 0.9363466f, 0.6768587f, 0.51328385f, + 0.56766605f, 0.098151624f, 0.33305728f, 0.98130906f, 0.3766839f, 0.47491795f, 0.08483446f, 0.22029644f, + 0.4897902f, 0.18942028f, 0.4379952f, 0.7034796f, 0.0109113455f, 0.64850605f, 0.16939592f, 0.25597447f, + 0.69195485f, 0.8975601f, 0.36334568f, 0.29471546f, 0.04788208f, 0.24217117f, 0.062181532f, 0.38556474f, + 0.6020277f, 0.03156215f, 0.93655676f, 0.81369543f, 0.010527074f, 0.2611835f, 0.6630776f, 0.3972702f, + 0.44551176f, 0.27424216f, 0.9016098f, 0.22050089f, 0.9146384f, 0.53226113f, 0.6005109f, 0.8900659f, + 0.4176172f, 0.21532834f, 0.4191329f, 0.9055267f, 0.12900633f, 0.6134902f, 0.008604288f, 0.76215106f, + 0.68473387f, 0.5211961f, 0.71459657f, 0.50056237f, 0.7766764f, 0.10418975f, 0.42657375f, 0.7218073f, + 0.9979084f, 0.7546957f, 0.1364128f, 0.8845484f, 0.38850087f, 0.39324278f, 0.04554516f, 0.42129284f, + 0.8536634f, 0.5697224f, 0.20877302f, 0.65390605f, 0.3396778f, 0.956497f, 0.066022694f, 0.34206223f, + 0.017213225f, 0.3030849f, 0.6576238f, 0.9813073f, 0.58397317f, 0.99017924f, 0.59782606f, 0.788768f, + 0.9008311f, 0.91796166f, 0.22013813f, 0.959695f, 0.80288273f, 0.2662105f, 0.26139832f, 0.080626905f}; + const std::vector fc2_experts_weights = { + 0.6255686f, 0.09472537f, 0.71121234f, 0.65789884f, 0.065598905f, 0.63625044f, 0.45933473f, 0.7284089f, + 0.7868948f, 0.0029274821f, 0.95854944f, 0.919321f, 0.6989418f, 0.043019474f, 0.32138962f, 0.35509557f, + 0.37150103f, 0.78196156f, 0.6817853f, 0.89608955f, 0.31273842f, 0.6682699f, 0.6778976f, 0.08370459f, + 0.014990091f, 0.24055547f, 0.84227383f, 0.029270172f, 0.0647831f, 0.7801003f, 0.7697645f, 0.91119635f, + 0.12253064f, 0.13405013f, 0.75649333f, 0.9348151f, 0.7991694f, 0.57832605f, 0.66478735f, 0.97456336f, + 0.17739785f, 0.2729941f, 0.8497335f, 0.15788019f, 0.22429371f, 0.86499554f, 0.65776104f, 0.661535f, + 0.2880798f, 0.49309975f, 0.9576164f, 0.19988996f, 0.5039311f, 0.73779976f, 0.15482187f, 0.98558843f, + 0.25019473f, 0.379932f, 0.36471486f, 0.17417055f, 0.009367704f, 0.7819258f, 0.63283706f, 0.031699598f, + 0.1781866f, 0.994184f, 0.6911175f, 0.7006223f, 0.20085096f, 0.28080195f, 0.42452294f, 0.40856004f, + 0.15737581f, 0.5411925f, 0.549694f, 0.4366895f, 0.5693159f, 0.3018247f, 0.63012594f, 0.6885702f, + 0.2366305f, 0.004210472f, 0.7617172f, 0.61926836f, 0.24570602f, 0.981851f, 0.273876f, 0.8378734f, + 0.75366426f, 0.080795944f, 0.82247066f, 0.040263534f, 0.22299266f, 0.41664255f, 0.16297674f, 0.98845494f, + 0.39971018f, 0.69859487f, 0.053544044f, 0.7878332f, 0.34460813f, 0.11966437f, 0.5731115f, 0.7422309f, + 0.93269855f, 0.19460368f, 0.25394785f, 0.59613144f, 0.6356306f, 0.6922361f, 0.7744376f, 0.38662314f, + 0.7777848f, 0.8686458f, 0.36938924f, 0.8557286f, 0.74428976f, 0.9410264f, 0.21586305f, 0.2530955f, + 0.35543054f, 0.52536315f, 0.8000995f, 0.21456867f, 0.750327f, 0.3208093f, 0.80205464f, 0.47626138f, + 0.061956525f, 0.22487706f, 0.13812399f, 0.74798125f, 0.1647259f, 0.45834088f, 0.6078779f, 0.22580266f, + 0.644235f, 0.011788309f, 0.14224577f, 0.0469383f, 0.34876132f, 0.3178513f, 0.5715967f, 0.40754277f, + 0.735041f, 0.9583977f, 0.67939556f, 0.30301625f, 0.031807184f, 0.68110096f, 0.25227106f, 0.75443816f, + 0.83424246f, 0.69286025f, 0.9691554f, 0.9748982f, 0.60586995f, 0.13568163f, 0.94672066f, 0.26275212f, + 0.2638232f, 0.9183893f, 0.88740516f, 0.65107566f, 0.5313419f, 0.07941705f, 0.44809794f, 0.9795632f, + 0.6273294f, 0.542809f, 0.3961745f, 0.32560885f, 0.79801136f, 0.53083426f, 0.8252871f, 0.4115007f, + 0.7184546f, 0.70638496f, 0.57973206f, 0.8141865f, 0.81332296f, 0.96346164f, 0.88438797f, 0.37215167f, + 0.0766899f, 0.5914087f, 0.49563587f, 0.3695873f, 0.41627264f, 0.5235164f, 0.86481494f, 0.6558706f, + 0.32245284f, 0.29438752f, 0.37618434f, 0.3067485f, 0.9496114f, 0.76482266f, 0.95148784f, 0.5015968f, + 0.60083544f, 0.67338234f, 0.026723444f, 0.5446483f, 0.466555f, 0.21967298f, 0.112026334f, 0.9426372f, + 0.906533f, 0.73173434f, 0.97712487f, 0.29709607f, 0.41363865f, 0.6893093f, 0.4173867f, 0.4018826f, + 0.086719275f, 0.63433063f, 0.1978364f, 0.5181831f, 0.9874878f, 0.34609234f, 0.34240413f, 0.8016564f, + 0.31617337f, 0.4570613f, 0.96686924f, 0.29501313f, 0.14229488f, 0.22017813f, 0.36137718f, 0.26275063f, + 0.24053413f, 0.70197225f, 0.58496886f, 0.33996922f, 0.11154431f, 0.34257007f, 0.28898042f, 0.33729053f, + 0.048938513f, 0.60771453f, 0.13263822f, 0.11060041f, 0.091483414f, 0.70869184f, 0.19898665f, 0.29362458f, + 0.8919203f, 0.7654821f, 0.7866956f, 0.02524674f, 0.1414501f, 0.3112445f, 0.9130488f, 0.5511502f, + 0.12605143f, 0.5031309f, 0.11166459f, 0.39045036f, 0.36251247f, 0.9328308f, 0.65486836f, 0.41281444f, + 0.5844644f, 0.35566723f, 0.6964502f, 0.6977819f, 0.63427305f, 0.30511153f, 0.92657536f, 0.42781502f, + 0.30534166f, 0.813157f, 0.90752834f, 0.9975799f, 0.64812917f, 0.32955307f, 0.753946f, 0.92897725f, + 0.009582937f, 0.43805653f, 0.15901726f, 0.5931799f, 0.7067924f, 0.39670604f, 0.45817143f, 0.7250554f, + 0.41596514f, 0.08011025f, 0.900068f, 0.24834275f, 0.44507074f, 0.5471632f, 0.46995157f, 0.029657006f, + 0.7294f, 0.27288425f, 0.2406702f, 0.6194577f, 0.23906898f, 0.26892018f, 0.33152503f, 0.3121612f, + 0.29118127f, 0.36515707f, 0.6299379f, 0.095391035f, 0.19735986f, 0.5072957f, 0.56953406f, 0.77614623f, + 0.14877802f, 0.65959847f, 0.7841949f, 0.7776301f, 0.03428924f, 0.3091979f, 0.07021719f, 0.18359429f, + 0.77849144f, 0.42534047f, 0.7123557f, 0.20649683f, 0.57597995f, 0.19757104f, 0.749946f, 0.2813105f, + 0.37462044f, 0.06618434f, 0.50165176f, 0.9747401f, 0.7426891f, 0.23322952f, 0.50672436f, 0.44517577f, + 0.09746289f, 0.89204556f, 0.50806034f, 0.6052985f, 0.2980855f, 0.26604044f, 0.5824448f, 0.68485546f, + 0.612149f, 0.25902748f, 0.9854489f, 0.4263978f, 0.19379246f, 0.26614368f, 0.9922104f, 0.5000241f, + 0.4321279f, 0.2919191f, 0.3689273f, 0.078885734f, 0.10265827f, 0.79264474f, 0.9277247f, 0.9771502f, + 0.13902885f, 0.77043164f, 0.19051671f, 0.7982801f, 0.86077714f, 0.8869355f, 0.86002564f, 0.81278664f, + 0.5097318f, 0.7297412f, 0.32111454f, 0.7177174f, 0.33929902f, 0.49160433f, 0.064810574f, 0.3692627f, + 0.23706353f, 0.3313396f, 0.18070674f, 0.05027789f, 0.53255826f, 0.8244896f, 0.9553747f, 0.7917771f, + 0.24083132f, 0.005495131f, 0.6896569f, 0.78015697f, 0.07074398f, 0.67929304f, 0.9227386f, 0.5302883f, + 0.19877058f, 0.90993816f, 0.71350795f, 0.8311006f, 0.16185725f, 0.79097277f, 0.15846318f, 0.99474716f, + 0.28815013f, 0.80128354f, 0.6001208f, 0.63250524f, 0.4233225f, 0.7053677f, 0.29161406f, 0.028710365f, + 0.30789846f, 0.8917693f, 0.36836517f, 0.6571592f, 0.3151368f, 0.8750746f, 0.7992451f, 0.6765068f, + 0.24441916f, 0.091435075f, 0.5188247f, 0.20667112f, 0.9110969f, 0.019512117f, 0.72343415f, 0.998457f, + 0.7504142f, 0.6704894f, 0.01892668f, 0.9809466f, 0.41447622f, 0.032795787f, 0.9935814f, 0.29653466f, + 0.4646262f, 0.95763975f, 0.15339965f, 0.14625502f, 0.58130866f, 0.43307304f, 0.6151709f, 0.08064735f, + 0.5149533f, 0.27762014f, 0.25419557f, 0.04218155f, 0.7651092f, 0.59631824f, 0.077278376f, 0.89677596f, + 0.6508104f, 0.5927816f, 0.2064318f, 0.57540226f, 0.9817701f, 0.84294224f, 0.11056489f, 0.9564106f, + 0.5387549f, 0.74048257f, 0.88833815f, 0.9262546f, 0.11023259f, 0.93783194f, 0.16041255f, 0.53748304f, + 0.1506182f, 0.39038336f, 0.47727865f, 0.44018233f, 0.42101204f, 0.53943527f, 0.99320936f, 0.79050577f, + 0.77973497f, 0.7001237f, 0.88709056f, 0.4769255f, 0.5397561f, 0.60289854f, 0.06393474f, 0.09722155f, + 0.5613007f, 0.30437487f, 0.49082512f, 0.3852706f, 0.5778314f, 0.8253078f, 0.33417904f, 0.9004303f, + 0.8947809f, 0.11625093f, 0.11388689f, 0.09546256f, 0.22598988f, 0.30536187f, 0.46236527f, 0.3784039f, + 0.24737573f, 0.3411532f, 0.31912774f, 0.9905191f, 0.31468558f, 0.14199954f, 0.7078488f, 0.47111923f, + 0.882782f, 0.8124163f, 0.9593644f, 0.13382024f, 0.8214317f, 0.9196194f, 0.25308424f, 0.95958996f}; + const std::vector fc1_experts_bias = { + 0.8748215f, 0.5054756f, 0.74107623f, 0.32518923f, 0.0639081f, 0.62639004f, 0.64906263f, 0.17322052f, + 0.7424998f, 0.07288867f, 0.93031204f, 0.9841952f, 0.6361292f, 0.18628561f, 0.7433356f, 0.5852079f, + 0.6359594f, 0.66432667f, 0.88067776f, 0.28508204f, 0.38752747f, 0.63635296f, 0.55448055f, 0.9031888f, + 0.23738074f, 0.48179168f, 0.5934266f, 0.3672055f, 0.84085834f, 0.5546908f, 0.03788501f, 0.44583207f, + 0.27322155f, 0.5485856f, 0.44189203f, 0.00403291f, 0.40888733f, 0.45211035f, 0.35256076f, 0.9593902f, + 0.39090043f, 0.8212086f, 0.62385887f, 0.07793343f, 0.61749303f, 0.9143678f, 0.17294967f, 0.17681253f, + 0.9894245f, 0.901755f, 0.221053f, 0.8008725f, 0.43603396f, 0.007035315f, 0.5375667f, 0.661547f, + 0.35001957f, 0.67394173f, 0.072449565f, 0.84650797f, 0.92626715f, 0.77573335f, 0.58474565f, 0.66467446f}; + const std::vector fc2_experts_bias = { + 0.13822609f, 0.3750633f, 0.45226622f, 0.22175694f, 0.13068998f, 0.8363088f, 0.8393226f, 0.045905888f, + 0.65910596f, 0.7034011f, 0.97498417f, 0.78927684f, 0.95966834f, 0.33630514f, 0.8501932f, 0.9067007f, + 0.027835965f, 0.09864664f, 0.6012027f, 0.7730189f, 0.25159347f, 0.55506724f, 0.49927413f, 0.62655383f, + 0.23132521f, 0.7820195f, 0.8325047f, 0.15307087f, 0.5048437f, 0.5013873f, 0.66055787f, 0.96579224f}; + const std::vector output = { + 1.3775184f, 2.0985768f, 2.091839f, 2.9706357f, 1.9404914f, 1.9915576f, 2.3302228f, 2.3702593f, + 0.51896286f, 0.7936432f, 0.9944805f, 1.3225251f, 0.73894113f, 0.87975955f, 1.0468717f, 1.1585085f, + 0.012911659f, 0.045757107f, 0.27884653f, 0.3585817f, 0.116771236f, 0.25755364f, 0.23161705f, 0.2906256f, + 4.8571277f, 5.649453f, 5.485141f, 5.306299f, 4.767025f, 6.9010167f, 5.3520975f, 6.711155f}; + + RunMoETest(input, + router_probs, + fc1_experts_weights, + fc2_experts_weights, + fc1_experts_bias, + fc2_experts_bias, + output, + num_rows, + num_experts, + hidden_size, + inter_size, + "relu"); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index f9a5c7618601c..9ab78cac3aca4 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -543,14 +543,12 @@ TEST(FunctionTest, TestInlinedLocalFunctionRemoved) { InferenceSessionWrapper session_object{session_options, GetEnvironment()}; std::stringstream sstr(serialized_model); - auto status = session_object.Load(sstr); - ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(session_object.Load(sstr)); auto model_proto = session_object.GetModel().ToProto(); ASSERT_EQ(1, model_proto.functions_size()); - status = session_object.Initialize(); - ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(session_object.Initialize()); // All functions removed model_proto = session_object.GetModel().ToProto(); @@ -591,5 +589,30 @@ TEST(FunctionTest, TestInlinedLocalFunctionNotRemoved) { #endif } +TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) { + // Verify this runs + constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/transform/gh_issue_18338.onnx"); + + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // Scalar shape for input_0 and output + const std::string input_names[] = {"input_0"}; + const std::string output_names[] = {"_val_3"}; + TensorShape input_shape; + MLFloat16 input_0_data{684.f}; + + OrtValue input_0; + Tensor::InitOrtValue(DataTypeImpl::GetType(), input_shape, &input_0_data, OrtMemoryInfo(), input_0); + + std::vector fetches(1); + RunOptions run_options; + ASSERT_STATUS_OK(session_object.Run(run_options, AsSpan(input_names), AsSpan({input_0}), + AsSpan(output_names), &fetches, 0)); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp index cf02d4f3628f9..87e3601612761 100644 --- a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp @@ -33,7 +33,7 @@ void Q4GEMM(benchmark::State& state, MLAS_BLK_QUANT_TYPE qtype) { auto B1 = RandomVectorUniform(static_cast(N * K), -1.0f, 1.0f); std::vector C1(static_cast(M * N)); - std::vector B1_packed(pack_b_size); + std::vector B1_packed(pack_b_size); MlasQ4GemmPackB(qtype, B1_packed.data(), B1.data(), N, K, N); MLAS_Q4_GEMM_DATA_PARAMS params1; diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp index baa8f1a830ea1..e6e34bc88ad59 100644 --- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sgemm.cpp @@ -128,7 +128,7 @@ BENCHMARK_CAPTURE(SGEMM, PACKB_TransA, true, true, false)->Apply(GemmSizeProduct static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) { b->ArgNames(sgemm_bench_arg_names); - ArgsProduct(b, {{1, 1024, 2048}, {4096}, {4096}}); + ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); } BENCHMARK_CAPTURE(SGEMM, LLM, false, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp new file mode 100644 index 0000000000000..2f2635dab0512 --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "mlas_q4.h" +#include "mlas_qnbit.h" + +#include + +#include "benchmark/benchmark.h" + +#include "bench_util.h" +#include "core/util/thread_utils.h" + +template +void SQNBITGEMM(benchmark::State& state) { + if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); + if (state.range(3) <= 0) throw std::invalid_argument("Threads must greater than 0!"); + + const size_t M = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); + const size_t threads = static_cast(state.range(3)); + + size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; + MlasBlockwiseQuantizedBufferSizes( + BlkBitWidth, BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), + QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); + + OrtThreadPoolParams tpo; + tpo.thread_pool_size = static_cast(threads); + tpo.auto_set_affinity = true; + + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + + auto A = RandomVectorUniform(static_cast(M * K), -1.0f, 1.0f); + auto B = RandomVectorUniform(static_cast(K * N), -1.0f, 1.0f); + std::vector C(static_cast(M * N)); + + std::vector QuantBData(QuantBDataSizeInBytes); + std::vector QuantBScale(QuantBScaleSize); + std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); + + MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), + Symmetric ? nullptr : QuantBZeroPoint.data(), + B.data(), BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), static_cast(N), + tp.get()); + + MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; + params.A = A.data(); + params.lda = K; + params.QuantBData = QuantBData.data(); + params.QuantBScale = QuantBScale.data(); + params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data(); + params.Bias = nullptr; + params.C = C.data(); + params.ldc = N; + + // warm up run + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, tp.get()); + + for (auto _ : state) { + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, tp.get()); + } +} + +static void GemmSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K", "Threads"}); + ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}}); +} + +BENCHMARK(SQNBITGEMM<4, 16, false>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 16, true>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 32, false>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 32, true>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 64, false>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 64, true>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 128, false>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 128, true>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 256, false>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 256, true>)->Apply(GemmSizeProducts)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index f836da8277bb8..07f0748fb7ed1 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -38,13 +38,17 @@ class MlasBlockwiseQdqTest : public MlasTestBase { int meta_rows; int meta_cols; - MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); + MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); int q_rows; int q_cols; - MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); + MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); - uint8_t* elements = InputElements.GetBuffer(q_rows * q_cols, true); + size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise, rows, columns, + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); int v = 7; for (int c = 0; c < columns; c++) { @@ -70,8 +74,8 @@ class MlasBlockwiseQdqTest : public MlasTestBase { } } - float* scales = InputScales.GetBuffer(meta_rows * meta_cols); - uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); + float* scales = InputScales.GetBuffer(q_scale_size); + uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(q_zp_size_in_bytes, true); if (zp) { for (int c = 0; c < meta_cols; c++) { for (int r = 0; r < meta_rows; r += 2) { diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h index 2861b0e746fdc..4db5c2bebca40 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.h +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h @@ -18,20 +18,6 @@ Module Name: #include "test_fp16.h" -inline bool -CloseEnough(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; -} - /** * @brief Test class for half precision GEMM * @tparam AType Data type of A matrix, can be either float or MLFp16 diff --git a/onnxruntime/test/mlas/unittest/test_main.cpp b/onnxruntime/test/mlas/unittest/test_main.cpp index 66b5a6a15db2b..505c0c01dfa90 100644 --- a/onnxruntime/test/mlas/unittest/test_main.cpp +++ b/onnxruntime/test/mlas/unittest/test_main.cpp @@ -1,17 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "test_util.h" - -#include #include +#include +#include + +#include "test_util.h" #if !defined(BUILD_MLAS_NO_ONNXRUNTIME) MLAS_THREADPOOL* GetMlasThreadPool(void) { - static MLAS_THREADPOOL* threadpool = new onnxruntime::concurrency::ThreadPool( + static auto threadpool = std::make_unique( &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 2, true); - return threadpool; + return threadpool.get(); } #else diff --git a/onnxruntime/test/mlas/unittest/test_q4gemm.h b/onnxruntime/test/mlas/unittest/test_q4gemm.h index 58a64491ae80b..97c6969b5bf91 100644 --- a/onnxruntime/test/mlas/unittest/test_q4gemm.h +++ b/onnxruntime/test/mlas/unittest/test_q4gemm.h @@ -19,20 +19,6 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" -inline bool -CloseEnough(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; -} - /** * @brief Test class for int4 block quantized GEMM * Note: only 2-D matmul supported for now diff --git a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp index a78a3261d1f2a..d3f601793a970 100644 --- a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp @@ -19,20 +19,6 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" -inline bool -CloseEnough(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; -} - template static void blkq8_dequant_reference(const int8_t* src, float* dst, size_t M, size_t K) { const size_t num_blks = K / QBlkLen; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp new file mode 100644 index 0000000000000..6c97d60301573 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -0,0 +1,270 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sqnbitgemm.h + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM. + +--*/ + +#include "test_util.h" +#include "mlas_q4.h" +#include "mlas_qnbit.h" + +/** + * @brief Test class for n-bit int block quantized GEMM + * Note: only 2-D matmul supported for now + */ +template +class MlasSQNBitGemmTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferQuantBData; + MatrixGuardBuffer BufferQuantBZeroPoint; + MatrixGuardBuffer BufferQuantBScale; + MatrixGuardBuffer BufferDequantizedB; + MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; + + void CallGemm(size_t M, + size_t N, + size_t K, + const float* A, + size_t lda, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + const float* Bias, + float* C, + size_t ldc, + MLAS_THREADPOOL* Threadpool) { + MLAS_SQNBIT_GEMM_DATA_PARAMS params; + params.A = A; + params.lda = lda; + params.Bias = Bias; + params.C = C; + params.ldc = ldc; + params.QuantBData = QuantBData; + params.QuantBScale = QuantBScale; + params.QuantBZeroPoint = QuantBZeroPoint; + params.PostProcessor = nullptr; + + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, Threadpool); + } + + void CallReferenceGemm(size_t M, + size_t N, + size_t K, + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + const float* Bias, + float* C) { + float* DequantizedBData = BufferDequantizedB.GetBuffer(K * N); + MlasDequantizeBlockwise( + DequantizedBData, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), GetMlasThreadPool()); + // Note: DequantizedBData is in column major layout. + + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const float* a = A + m * K; + const float* b = DequantizedBData + n * K; + float* c = C + (m * N) + n; + + float sum = Bias == nullptr ? 0.0f : Bias[n]; + for (size_t k = 0; k < K; k++) { + sum += (*a) * (*b); + b += 1; + a += 1; + } + *c = sum; + } + } + } + + public: + void Test(size_t M, size_t N, size_t K, + bool WithBias, bool Symmetric, bool WithThreadpool) { + MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr; + + const float* A = BufferA.GetBuffer(K * M); + + const float* B = BufferB.GetBuffer(N * K); + + const float* Bias = nullptr; + if (WithBias) { + Bias = BufferBias.GetBuffer(N); + } + +#if 0 + auto print_matrix = [](size_t ncols, size_t nrows, const float* data) { + for (size_t row = 0; row < nrows; ++row) { + for (size_t col = 0; col < ncols; ++col) { + std::cout << data[row * nrows + col] << "\t"; + } + std::cout << "\n"; + } + }; + + std::cout << "A:\n"; + print_matrix(M, K, A); + std::cout << "B:\n"; + print_matrix(K, N, B); +#endif + + float* C = BufferC.GetBuffer(N * M, true); + float* CReference = BufferCReference.GetBuffer(N * M, true); + + // pack B + uint8_t* QuantBData = nullptr; + float* QuantBScale = nullptr; + uint8_t* QuantBZeroPoint = nullptr; + { + size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; + MlasBlockwiseQuantizedBufferSizes(BlkBitWidth, BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), + QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); + + QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes); + QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize); + if (Symmetric) { + QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes); + } + + MlasQuantizeBlockwise(QuantBData, QuantBScale, QuantBZeroPoint, + B, BlkLen, + /* columnwise */ true, + static_cast(K), static_cast(N), + static_cast(N), + GetMlasThreadPool()); + } + + CallGemm(M, N, K, A, /* lda */ K, QuantBData, QuantBScale, QuantBZeroPoint, Bias, C, /* ldc */ N, Threadpool); + CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); + + size_t f = 0; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + ASSERT_TRUE(CloseEnough(C[f], CReference[f])) + << "Expected: " << CReference[f] << " Actual: " << C[f] << "@[" << m << "x" << n << "], " + << "M=" << M << ", N=" << N << ", K=" << K; + } + } + } + + public: + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SQNBitGemm") + + "BlkBitWidth" + std::to_string(BlkBitWidth) + + "BlkLen" + std::to_string(BlkLen); + return suite_name.c_str(); + } +}; + +// +// Short Execute() test helper to register each test separately by all parameters. +// +template +class SQNBitGemmShortExecuteTest : public MlasTestFixture> { + public: + explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, + bool WithThreadpool, bool Symmetric, bool WithBias) + : M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), Symmetric_(Symmetric), WithBias_(WithBias) { + } + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test( + M_, N_, K_, WithThreadpool_, Symmetric_, WithBias_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, + bool WithThreadpool, bool Symmetric, bool WithBias) { + std::stringstream ss; + ss << (WithThreadpool ? "SingleThread" : "Threaded") + << "/isSymmetric" << Symmetric + << "/M" << M << "xN" << N << "xK" << K + << "/hasBias" << WithBias; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSQNBitGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture>* { + return new SQNBitGemmShortExecuteTest( + M, N, K, WithThreadpool, Symmetric, WithBias); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t test_registered = 0; + + if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen)) { + for (bool WithThreadpool : {false, true}) { + for (bool Symmetric : {false, true}) { + for (size_t b = 1; b < 16; b++) { + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false); + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + } + for (size_t b = 16; b <= 256; b <<= 1) { + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false); + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + } + for (size_t b = 256; b < 320; b += 32) { + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + } + for (size_t b = 1; b < 96; b++) { + test_registered += RegisterSingleTest(1, b, 32, WithThreadpool, Symmetric, false); + test_registered += RegisterSingleTest(1, 32, b, WithThreadpool, Symmetric, true); + test_registered += RegisterSingleTest(1, b, b, WithThreadpool, Symmetric, false); + } + test_registered += RegisterSingleTest(43, 500, 401, WithThreadpool, Symmetric, true); + + // test_registered += RegisterSingleTest(1001, 1027, 1031, WithThreadpool, Symmetric, false); + } + } + } + + return test_registered; + } + + private: + size_t M_, N_, K_; + bool WithThreadpool_, Symmetric_, WithBias_; +}; + +static size_t SQNBitGemmRegisterAllShortExecuteTests() { + size_t count = 0; + + count += SQNBitGemmShortExecuteTest<4, 16>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<4, 32>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<4, 64>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<4, 128>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<4, 256>::RegisterShortExecuteTests(); + + return count; +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + if (is_short_execute) { + return SQNBitGemmRegisterAllShortExecuteTests() > 0; + } + return false; +}); diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h index db528ef7291cc..8eecda900ff27 100644 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ b/onnxruntime/test/mlas/unittest/test_util.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -253,3 +254,16 @@ inline void ReorderInputNchw(const int64_t* input_shape, const float* S, float* D += spatial_count * nchwc_channel_count; } } + +inline bool CloseEnough(float actual, float expected) { + if (std::isnan(actual)) { + return std::isnan(expected); + } + float diff = std::abs(actual - expected); + float top = std::max(std::abs(actual), std::abs(expected)); + float ratio = 0; + if (top > 0.0001) { + ratio = diff / top; + } + return ratio < 0.005; +} diff --git a/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc b/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc index bd2abadf49b81..d866045ba4962 100644 --- a/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc +++ b/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc @@ -91,6 +91,8 @@ BENCHMARK(BM_FindMinMaxMlasSSE2) ->Arg(98304) ->Arg(160000); +#ifdef MLAS_TARGET_AMD64 + // MLAS avx implementation static void BM_FindMinMaxMlasAvx(benchmark::State& state) { const size_t batch_size = static_cast(state.range(0)); @@ -115,3 +117,5 @@ BENCHMARK(BM_FindMinMaxMlasAvx) ->Arg(80000) ->Arg(98304) ->Arg(160000); + +#endif // MLAS_TARGET_AMD64 diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc index a03d0da2538d4..9dcedd1fd7681 100644 --- a/onnxruntime/test/optimizer/compute_optimizer_test.cc +++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc @@ -847,7 +847,7 @@ Test graph includes multiple equivalent subgraphs as below. Add an Identity node because currently, we don't allow Gather generates graph output. */ TEST(ComputeOptimizerTests, GatherLayerNormalization) { - std::vector> test_config_pairs{ + std::vector> test_config_pairs{ // { // is_scalar_slice, // ln_axis_before_propagation, @@ -929,13 +929,6 @@ TEST(ComputeOptimizerTests, GatherLayerNormalization) { const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); TEST_RETURN_IF_NOT(slice_out_shape != nullptr); - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - - auto& axis_attr = attrs.at("axis"); - auto axis_value = (int)axis_attr.i(); - TEST_RETURN_IF_NOT(axis_value == ln_axis_after); - if (is_scalar_slice) { TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 2); TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && @@ -951,10 +944,15 @@ TEST(ComputeOptimizerTests, GatherLayerNormalization) { TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && slice_out_shape->dim(2).dim_value() == 256); } - } else { TEST_RETURN_IF_NOT(producer_node == nullptr); } + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == ln_axis_after); } } @@ -2841,165 +2839,110 @@ Test graph include multiple equivalent subgraphs as below. Add an Identity node because currently we don't allow Reshape generate graph output. */ -TEST(ComputeOptimizerTests, ReshapeLayerNormalization_PropagationOnOneBranch) { - const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); - auto pre_graph_checker = [](Graph& graph) -> Status { - auto op_count_pre = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); - TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); - return Status::OK(); - }; - - auto post_graph_checker = [](Graph& graph) { - auto op_count_post = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_post.size() == 3U); - TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); - - for (Node& node : graph.Nodes()) { - if (node.OpType() == "LayerNormalization") { - const auto& input_defs = node.InputDefs(); - - { - auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); - TEST_RETURN_IF_NOT(producer_node != nullptr); - TEST_RETURN_IF_NOT(producer_node->OpType() == "Reshape"); - - InlinedVector values; - constexpr bool require_constant = true; - NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); - TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, require_constant)); - TEST_RETURN_IF_NOT(values.size() == 2); - TEST_RETURN_IF_NOT(values[0] == -1); - TEST_RETURN_IF_NOT(values[1] == 1024); - } - - { - auto producer_node = graph.GetProducerNode(input_defs[1]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } - - { - auto producer_node = graph.GetProducerNode(input_defs[2]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } - } - } - return Status::OK(); +TEST(ComputeOptimizerTests, ReshapeLayerNormalization) { + std::vector> test_config_pairs{ + // { + // ln_axis_before_propagation, + // expected_ln_axis_after_propagation, + // expected to propagate + // } + {0, 0, false}, + {1, 1, false}, + {2, 1, true}, + {-3, -3, false}, + {-2, -2, false}, + {-1, -1, true}, }; - std::vector fist_dim_values = {-1, 128}; - for (auto first_dim_value : fist_dim_values) { - auto build_test_case = [&first_dim_value](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput({{4, 32, 1024}}); - auto* input2_arg = builder.MakeInput({{1024}}); - auto* input3_arg = builder.MakeInput({{1024}}); - auto* ln_out = builder.MakeIntermediate(); - builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) - .AddAttribute("axis", static_cast(-1)); - - auto* shape_initializer = builder.MakeInitializer({2}, {first_dim_value, 1024}); - auto* reshape_out = builder.MakeIntermediate(); - builder.AddNode("Reshape", {ln_out, shape_initializer}, {reshape_out}); + for (auto p : test_config_pairs) { + int64_t ln_axis_before = std::get<0>(p); + int64_t ln_axis_after = std::get<1>(p); + bool expected_to_propagate = std::get<2>(p); - auto* identity_out = builder.MakeOutput(); - builder.AddNode("Identity", {reshape_out}, {identity_out}); + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto pre_graph_checker = [](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + return Status::OK(); }; - const std::vector opsets{12, 13, 14}; - for (auto& opset_version : opsets) { - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer), - TransformerLevel::Level1, - 1, pre_graph_checker, post_graph_checker)); - } - } -} + auto post_graph_checker = [ln_axis_after, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); -/* -Test graph include multiple equivalent subgraphs as below. - graph input [4, 32, 1024] (float) graph input [1024] (float) graph input [1024] (float) - | | / - \_____________ _______/ __________________________/ - \ / / - LayerNormalization - | - Reshape - | - Identity - | - graph out [128, 1024] (float) + for (Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + const auto& input_defs = node.InputDefs(); -Add an Identity node because currently we don't allow Reshape generate graph output. -*/ -TEST(ComputeOptimizerTests, ReshapeLayerNormalization_NoPropagation) { - const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); - auto pre_graph_checker = [](Graph& graph) -> Status { - auto op_count_pre = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); - TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); - return Status::OK(); - }; + if (expected_to_propagate) { + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Reshape"); - auto post_graph_checker = [](Graph& graph) { - auto op_count_post = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_post.size() == 3U); - TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, require_constant)); + TEST_RETURN_IF_NOT(values.size() == 2); + TEST_RETURN_IF_NOT(values[0] == -1); + TEST_RETURN_IF_NOT(values[1] == 1024); + } else { + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + TEST_RETURN_IF_NOT(producer_node == nullptr); + } - for (Node& node : graph.Nodes()) { - if (node.OpType() == "LayerNormalization") { - const auto& input_defs = node.InputDefs(); + { + auto producer_node = graph.GetProducerNode(input_defs[1]->Name()); + TEST_RETURN_IF_NOT(producer_node == nullptr); + } - { - auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } + { + auto producer_node = graph.GetProducerNode(input_defs[2]->Name()); + TEST_RETURN_IF_NOT(producer_node == nullptr); + } - { - auto producer_node = graph.GetProducerNode(input_defs[1]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - { - auto producer_node = graph.GetProducerNode(input_defs[2]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == ln_axis_after); } } - } - return Status::OK(); - }; - - std::vector fist_dim_values = {-1, 128}; - for (auto first_dim_value : fist_dim_values) { - auto build_test_case = [&first_dim_value](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput({{4, 32, 1024}}); - auto* input2_arg = builder.MakeInput({{1024}}); - auto* input3_arg = builder.MakeInput({{1024}}); - auto* ln_out = builder.MakeIntermediate(); - builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) - .AddAttribute("axis", static_cast(1)); - - auto* shape_initializer = builder.MakeInitializer({2}, {first_dim_value, 1024}); - auto* reshape_out = builder.MakeIntermediate(); - builder.AddNode("Reshape", {ln_out, shape_initializer}, {reshape_out}); - - auto* identity_out = builder.MakeOutput(); - builder.AddNode("Identity", {reshape_out}, {identity_out}); + return Status::OK(); }; - const std::vector opsets{12, 13, 14}; - for (auto& opset_version : opsets) { - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer), - TransformerLevel::Level1, - 1, pre_graph_checker, post_graph_checker)); + std::vector fist_dim_values = {-1, 128}; + for (auto first_dim_value : fist_dim_values) { + auto build_test_case = [ln_axis_before, &first_dim_value](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{4, 32, 1024}}); + auto* input2_arg = builder.MakeInput({{1024}}); + auto* input3_arg = builder.MakeInput({{1024}}); + auto* ln_out = builder.MakeIntermediate(); + builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) + .AddAttribute("axis", ln_axis_before); + + auto* shape_initializer = builder.MakeInitializer({2}, {first_dim_value, 1024}); + auto* reshape_out = builder.MakeIntermediate(); + builder.AddNode("Reshape", {ln_out, shape_initializer}, {reshape_out}); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {reshape_out}, {identity_out}); + }; + + const std::vector opsets{12, 13, 14}; + for (auto& opset_version : opsets) { + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } } } } diff --git a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc index feff607703341..7a67747f7cf4c 100644 --- a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc +++ b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc @@ -63,7 +63,8 @@ std::function GetGraphBuilder(const GraphConfig& config return graph.ToGraphProto(); }; - auto* if_input = builder.MakeInitializerBool({}, {true}); + // Make this an input to prevent If constant folding affecting this test + auto* if_input = builder.MakeInput({1}, {true}); auto* if_output = builder.MakeOutput(); Node& if_node = builder.AddNode("If", {if_input}, {if_output}); if_node.AddAttribute("then_branch", create_if_subgraph(true)); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index e0f63ea58e772..ef6e2d531bc1a 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -10,6 +10,8 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" +#include "onnx/defs/parser.h" +#include "onnx/defs/printer.h" #include "asserts.h" #include "core/common/span_utils.h" @@ -1022,6 +1024,314 @@ TEST_F(GraphTransformationTests, ConstantFoldingStringInitializer) { ASSERT_EQ(op_to_count.size(), 0U) << "Identity node should have been removed"; } +TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInlining) { + // This test covers the following necessary cases: + // The input refers to the explicit or implicit inputs of If node. + // The output of the node is the output of the subgraph being inlined. + // Constant nodes and initializers are promoted to the outer graph. + // The initializer or a constant node is the output of the subgraph being inlined. + // Nested subgraphs names are renamed as appropriate. + // In all If node is constant folded twice. The last If node is not constant + // folded because the input is indirectly dependent on the size of the input. + // XXX: Can we constant fold Size() if the graph input shape is fixed? + + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[128] x, float[128] x1) => (float[N] y) + { + y = local.aten_gather (x, x1) + } + < + opset_import: [ "" : 16, "local" : 1], + domain: "local" + > + aten_gather (self, index) => (result_16) + { + tmp = Shape (index) + tmp_0 = Size (tmp) + int64_0 = Constant () + int64_0_cast = CastLike (int64_0, tmp_0) + cond = Equal (tmp_0, int64_0_cast) + result_16 = If (cond) ( result) { + result = Identity (self) + }, else_branch: graph = elseGraph_10 () => ( result_15) { + tmp_1 = Shape (self) + tmp_2 = Size (tmp_1) + int64_0_3 = Constant () + int64_0_3_cast = CastLike (int64_0_3, tmp_2) + cond_4 = Equal (tmp_2, int64_0_3_cast) + self_8 = If (cond_4) ( self_6) { + tmp_5 = Constant () + self_6 = Reshape (self, tmp_5) + }, else_branch: graph = elseGraph_13 () => ( self_7) { + self_7 = Identity (self) + }> + tmp_9 = Size (index) + int64_0_10 = Constant () + int64_0_10_cast = CastLike (int64_0_10, tmp_9) + cond_11 = Equal (tmp_9, int64_0_10_cast) + result_15 = If (cond_11) ( result_12) { + result_12 = CastLike (index, self_8) + }, else_branch: graph = elseGraph_15 () => ( result_14) { + index_13 = Cast (index) + result_14 = GatherElements (self_8, index_13) + }> + }> + } +)"; + + ONNX_NAMESPACE::OnnxParser parser(code); + ONNX_NAMESPACE::ModelProto model_proto; + auto parse_status = parser.Parse(model_proto); + ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage(); + ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected."; + + { + // Test that the model is loadable and check the function call node. + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(std::move(model_proto), p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["local.aten_gather"], 1); + model_proto = p_model->ToProto(); + } + + std::string serialized_model; + const bool serialization_status = model_proto.SerializeToString(&serialized_model); + ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string"; + + // AOT inlining is necessary in this case, so the If nodes within the function + // are brought out to the outer scope. So we load this into a session object. + + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + + std::stringstream sstr(serialized_model); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // const auto resulting_model_proto = session_object.GetModel().ToProto(); + // std::string printed_model = ONNX_NAMESPACE::ProtoToString(resulting_model_proto); + // ASSERT_FALSE(printed_model.empty()); + // std::cout << printed_model << std::endl; + + // This is the resulting model proto. + // The remaining If node is not constant foldable because Size() does not constant fold + // although the shape is fixed. + /* + < + ir_version: 8, + opset_import: ["" : 16, "local" : 1, + "com.microsoft.nchwc" : 1, + "ai.onnx.ml" : 4, + "com.ms.internal.nhwc" : 20, + "ai.onnx.training" : 1, + "ai.onnx.preview.training" : 1, + "com.microsoft" : 1, + "com.microsoft.experimental" : 1, + "org.pytorch.aten" : 1] + > + agraph (float[128] x, float[128] x1) => (float[128] y) { + _if_elseGraph_10__inlfunc_aten_gather_tmp_9 = Size (x1) + _if_elseGraph_10__inlfunc_aten_gather_cond_11 = + Equal (_if_elseGraph_10__inlfunc_aten_gather_tmp_9, ortshared_7_0_1_0_token_10) + y = If (_if_elseGraph_10__inlfunc_aten_gather_cond_11) (float[128] _inlfunc_aten_gather_result_12) { + _inlfunc_aten_gather_result_12 = Cast (x1) + }, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) { + _inlfunc_aten_gather_index_13 = Cast (x1) + _inlfunc_aten_gather_result_14 = GatherElements (x, _inlfunc_aten_gather_index_13) + }> + } + */ + + auto& graph = session_object.GetModel().MainGraph(); + auto op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["local.aten_gather"], 0); + ASSERT_EQ(op_to_count["If"], 1); +} + +TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningRebuildEdges) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "transform_nested_ifs_toplogical_sorted_nodes.onnx"; + + SessionOptions so; + so.session_logid = "GraphTransformationTests.ConstantFoldingIfConstantInliningRebuildEdges"; + + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); + + auto& graph = session_object.GetModel().MainGraph(); + auto op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["pkg.onnxscript.torch_lib._aten_linalg_vector_norm_no_dim_onnx"], 0); + ASSERT_EQ(op_to_count["If"], 0); + ASSERT_EQ(op_to_count["Reshape"], 1); + ASSERT_EQ(op_to_count["Abs"], 1); + ASSERT_EQ(op_to_count["Mul"], 1); + ASSERT_EQ(op_to_count["ReduceSum"], 1); + ASSERT_EQ(op_to_count["Sqrt"], 1); + ASSERT_EQ(op_to_count["Cast"], 2); +} + +TEST_F(GraphTransformationTests, ConstantFoldingIfConstantInliningEdgesWithMiddleArgNonExisting) { + // This model has a Resize() call with a middle argument non-existing. + // We want to make sure that the input edges for that Resize() node + // are properly rebuilt with a middle argument non-existing + // during If constant folding + // This test is only valid if Resize() node resides in the nested subgraph which gets inlined + // however, the destination graph must not be the main graph. Then we test that the edges are rebuild + // properly. Also Resize() should not be the first node in the resulting subgraph, so it has edges + const char* code = R"( + < + ir_version: 8, + opset_import: [ "" : 16, "local" : 1 ] + > + agraph (float[128] x, float[128] x1) => (float[N] y) + { + y = local.aten_gather (x, x1) + } + < + opset_import: [ "" : 16, "local" : 1], + domain: "local" + > + aten_gather (self, index) => (result_16) + { + resize_scales = Constant () + tmp_0 = Size (index) + int64_0 = Constant () + int64_0_cast = CastLike (int64_0, tmp_0) + cond = Equal (tmp_0, int64_0_cast) + result_16 = If (cond) ( result) { + result = Identity (self) + }, else_branch: graph = elseGraph_10 () => ( result_15) { + tmp_1 = Shape (self) + tmp_2 = Size (tmp_1) + int64_0_3 = Constant () + int64_0_3_cast = CastLike (int64_0_3, tmp_2) + cond_4 = Equal (tmp_2, int64_0_3_cast) + self_8 = If (cond_4) ( self_6) { + tmp_5 = Constant () + self_6 = Reshape (self, tmp_5) + }, else_branch: graph = elseGraph_13 () => ( self_7) { + self_71 = Mul(self, self) + float_size = CastLike (tmp_0, resize_scales) + non_constant_resize_scales = Mul(float_size, resize_scales) + self_7 = Resize(self_71,, non_constant_resize_scales) + }> + tmp_9 = Size (index) + int64_0_10 = Constant () + int64_0_10_cast = CastLike (int64_0_10, tmp_9) + cond_11 = Equal (tmp_9, int64_0_10_cast) + result_15 = If (cond_11) ( result_12) { + result_12 = CastLike (index, self_8) + }, else_branch: graph = elseGraph_15 () => ( result_14) { + index_13 = Cast (index) + result_14 = GatherElements (self_8, index_13) + }> + }> + } + )"; + + /** Optimized model graph + < + ir_version: 8, + opset_import: ["" : 16, + "local" : 1, + "com.microsoft.nchwc" : 1, + "ai.onnx.ml" : 4, + "ai.onnx.training" : 1, + "ai.onnx.preview.training" : 1, + "com.microsoft" : 1, + "com.microsoft.experimental" : 1, "org.pytorch.aten" : 1] + > + agraph (float[128] x, float[128] x1) => (float[128] y) + + { + _inlfunc_aten_gather_tmp_0 = Size (x1) + _inlfunc_aten_gather_cond = Equal (_inlfunc_aten_gather_tmp_0, ortshared_7_0_1_0_token_8) + y = If (_inlfunc_aten_gather_cond) + (float[128] _inlfunc_aten_gather_result) { + _inlfunc_aten_gather_result = Identity (x) + }, else_branch: graph = elseGraph_10 () => (float[128] _inlfunc_aten_gather_result_15) + + { + _if_else_branch__inlfunc_aten_gather_self_71 = Mul (x, x) + _if_else_branch__inlfunc_aten_gather_float_size = Cast (_inlfunc_aten_gather_tmp_0) + _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales = Mul ( + _if_else_branch__inlfunc_aten_gather_float_size, _inlfunc_aten_gather_resize_scales) + _inlfunc_aten_gather_self_8 = Resize ( + _if_else_branch__inlfunc_aten_gather_self_71, , + _if_else_branch__inlfunc_aten_gather_non_constant_resize_scales) + _inlfunc_aten_gather_tmp_9 = Size (x1) + _inlfunc_aten_gather_cond_11 = Equal (_inlfunc_aten_gather_tmp_9, _inlfunc_aten_gather_int64_0_10) + _inlfunc_aten_gather_result_15 = If (_inlfunc_aten_gather_cond_11) + (float[128] _inlfunc_aten_gather_result_12) { + _inlfunc_aten_gather_result_12 = Cast (x1) + }, else_branch: graph = elseGraph_15 () => (float[128] _inlfunc_aten_gather_result_14) { + _inlfunc_aten_gather_index_13 = Cast (x1) + _inlfunc_aten_gather_result_14 = GatherElements ( + _inlfunc_aten_gather_self_8, _inlfunc_aten_gather_index_13) + }> + }> + } + + */ + + ONNX_NAMESPACE::OnnxParser parser(code); + ONNX_NAMESPACE::ModelProto model_proto; + auto parse_status = parser.Parse(model_proto); + ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage(); + ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected."; + + std::string serialized_model; + const bool serialization_status = model_proto.SerializeToString(&serialized_model); + ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string"; + + // AOT inlining is necessary in this case, so the If nodes within the function + // are brought out to the outer scope. So we load this into a session object. + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + std::stringstream sstr(serialized_model); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // Let's verify the correctness of the rebuild edges in the Resize node that still + // resides within an if else subgraph. + auto& graph = session_object.GetModel().MainGraph(); + auto op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["If"], 2); + ASSERT_EQ(op_to_count["Resize"], 1); + + auto if_node = std::find_if(graph.Nodes().begin(), graph.Nodes().end(), + [](const auto& node) { return node.OpType() == "If"; }); + ASSERT_NE(graph.Nodes().cend(), if_node); + // Resize is in the else branch + auto subgraph_map = if_node->GetAttributeNameToSubgraphMap(); + auto branch = subgraph_map.find("else_branch"); + ASSERT_NE(subgraph_map.cend(), branch); + + auto resize_node = std::find_if(branch->second->Nodes().begin(), branch->second->Nodes().end(), + [](const auto& node) { return node.OpType() == "Resize"; }); + ASSERT_NE(branch->second->Nodes().cend(), resize_node); + + // Check the edges + ASSERT_EQ(2U, resize_node->GetInputEdgesCount()); + // Should have input edges with arg_pos 0 and 2 + // With 1 is missing + InlinedHashSet dest_edges; + auto zero_edge = resize_node->InputEdgesBegin(); + dest_edges.insert(zero_edge->GetDstArgIndex()); + ++zero_edge; + dest_edges.insert(zero_edge->GetDstArgIndex()); + ASSERT_TRUE(dest_edges.find(0) != dest_edges.end()); + ASSERT_TRUE(dest_edges.find(2) != dest_edges.end()); +} + // Check transformations in the case of a subgraph with constant inputs. TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx"; @@ -6756,7 +7066,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; // OpSet-12 { @@ -6779,8 +7092,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14 @@ -6808,8 +7121,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-18 @@ -6837,8 +7150,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } @@ -6869,7 +7182,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; // OpSet-12 { @@ -6888,8 +7204,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14 @@ -6909,8 +7225,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-18 @@ -6930,13 +7246,180 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + +TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_out_3 = builder.MakeIntermediate(); + auto* transpose_out_1 = builder.MakeOutput(); + auto* transpose_out_2 = builder.MakeOutput(); + auto* transpose_out_3 = builder.MakeOutput(); + + builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(-2)); + builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; + + // OpSet-12 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axes").ints().at(0))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } + + // OpSet-14 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } + + // OpSet-18 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + +TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInitializer({2, 3, 3, 3}, std::vector(54)); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_out_3 = builder.MakeIntermediate(); + auto* transpose_out_1 = builder.MakeOutput(); + auto* transpose_out_2 = builder.MakeOutput(); + auto* transpose_out_3 = builder.MakeOutput(); + + builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(-2)); + builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); } TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; auto post_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); @@ -6976,8 +7459,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // Invalid Gather indices. @@ -7012,8 +7495,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // Invalid Gather axis. @@ -7048,8 +7531,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } @@ -7096,8 +7579,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14, Tind is int64. @@ -7135,8 +7618,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 2008d96539dca..e64117925eb57 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -466,11 +466,11 @@ GetQDQTestCaseFn BuildDoubleQDQWithoutLastOutput(int output_index, bool use_cont } template -GetQDQTestCaseFn BuildQDQSplitTestCase( - const std::vector& input_shape, - const int64_t& axis, - bool use_contrib_qdq = false) { - return [input_shape, axis, use_contrib_qdq](ModelTestBuilder& builder) { +GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector& input_shape, + const int64_t& axis, + bool use_diff_output_scale, + bool use_contrib_qdq = false) { + return [input_shape, axis, use_diff_output_scale, use_contrib_qdq](ModelTestBuilder& builder) { auto* input_arg = builder.MakeInput(input_shape, std::numeric_limits::min(), std::numeric_limits::max()); @@ -478,16 +478,30 @@ GetQDQTestCaseFn BuildQDQSplitTestCase( InputType dq_zp = std::numeric_limits::max() / 2; OutputType q_zp = std::numeric_limits::max() / 2; auto* dq_output = builder.MakeIntermediate(); - builder.AddDequantizeLinearNode(input_arg, .003f, dq_zp, dq_output, use_contrib_qdq); + constexpr float input_scale = 0.003f; + builder.AddDequantizeLinearNode(input_arg, input_scale, dq_zp, dq_output, use_contrib_qdq); // add Split + std::vector split_inputs; + split_inputs.push_back(dq_output); + + // Use the optional 'split' input when testing Split 13 + int opset = builder.DomainToVersionMap().find(kOnnxDomain)->second; + if (opset >= 13 && opset < 18) { + int64_t dim = input_shape[axis]; + int64_t split_size = dim / 3; + split_inputs.push_back(builder.Make1DInitializer(std::vector{split_size, + split_size, dim - (2 * split_size)})); + } auto* split_output_1 = builder.MakeIntermediate(); auto* split_output_2 = builder.MakeIntermediate(); auto* split_output_3 = builder.MakeIntermediate(); - Node& split_node = builder.AddNode("Split", {dq_output}, {split_output_1, split_output_2, split_output_3}); + Node& split_node = builder.AddNode("Split", split_inputs, {split_output_1, split_output_2, split_output_3}); split_node.AddAttribute("axis", axis); - if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) { + + // Use the 'num_outputs' attribute when testing Split >= 18 + if (opset >= 18) { split_node.AddAttribute("num_outputs", static_cast(3)); } @@ -495,11 +509,12 @@ GetQDQTestCaseFn BuildQDQSplitTestCase( auto* q_split_output_1 = builder.MakeOutput(); auto* q_split_output_2 = builder.MakeOutput(); auto* q_split_output_3 = builder.MakeOutput(); - builder.AddQuantizeLinearNode(split_output_1, .003f, q_zp, q_split_output_1, + float output_scale = use_diff_output_scale ? input_scale + 0.001f : input_scale; + builder.AddQuantizeLinearNode(split_output_1, output_scale, q_zp, q_split_output_1, use_contrib_qdq); // Model input (node_token_1) - builder.AddQuantizeLinearNode(split_output_2, .003f, q_zp, q_split_output_2, + builder.AddQuantizeLinearNode(split_output_2, output_scale, q_zp, q_split_output_2, use_contrib_qdq); // Model input (node_token_2) - builder.AddQuantizeLinearNode(split_output_3, .003f, q_zp, q_split_output_3, + builder.AddQuantizeLinearNode(split_output_3, output_scale, q_zp, q_split_output_3, use_contrib_qdq); }; } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 1bf1cbacf479e..17dd2e80f9f88 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -1210,27 +1210,51 @@ TEST(QDQTransformerTests, DoubleQDQ_Without_Last_Node_Being_Output) { // Runs a test that checks if DQ -> Split -> Q (many) is replaced with just Split. template static void RunDropSplitQDQTestCase(const std::vector& input_shape, int64_t axis, - bool use_contrib_qdq = false) { - auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + bool all_same_quant_params, bool use_contrib_qdq = false) { + auto check_graph = [all_same_quant_params, use_contrib_qdq](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + int expected_q_ops = all_same_quant_params ? 0 : 3; + int expected_dq_ops = all_same_quant_params ? 0 : 1; EXPECT_EQ(op_to_count["Split"], 1); - EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); - EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], expected_q_ops); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], expected_dq_ops); }; - TransformerTester(BuildQDQSplitTestCase(input_shape, axis, use_contrib_qdq), + TransformerTester(BuildQDQSplitTestCase(input_shape, axis, !all_same_quant_params, + use_contrib_qdq), check_graph, TransformerLevel::Level1, TransformerLevel::Level2, - {12, 18, 19}); + {12, 13, 18, 19}); // Test different ways to specify the split in each opset: + // 12 - split into equal parts without explicit 'split' attribute + // 13 - use optional 'split' input to split into 3 parts + // 18 - use 'num_outputs' attribute to split into 3 parts + // 19 - use 'num_outputs' attribute to split into 3 parts } // Test that DQ -> Split -> Q (many) is replaced with just Split for various quantization types. TEST(QDQTransformerTests, Split) { - RunDropSplitQDQTestCase({6, 18, 54}, 0); - RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft int8 QDQ ops - RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft int16 QDQ ops - RunDropSplitQDQTestCase({6, 18, 54}, 0, true); // Use com.microsoft uint16 QDQ ops + // Test cases that drop Q/DQ ops from DQ -> Split -> Q (many). + // This happens when all the Q/DQ ops have equal and constant quantization parameters. + { + constexpr bool ALL_SAME_QUANT_PARAMS = true; + constexpr bool USE_CONTRIB_QDQ_OPS = true; + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, ALL_SAME_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + } + + // Test cases that DO NOT drop Q/DQ ops from DQ -> Split -> Q (many) + // This happens when the Q/DQ ops do not have equal and constant quantization parameters. + { + constexpr bool DIFF_QUANT_PARAMS = false; + constexpr bool USE_CONTRIB_QDQ_OPS = true; + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + RunDropSplitQDQTestCase({6, 18, 54}, 0, DIFF_QUANT_PARAMS, USE_CONTRIB_QDQ_OPS); + } } // Because split isn't one the supported ops, this will stay the same diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index c2e7577a7ca52..859e082716760 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include @@ -410,7 +411,7 @@ ::std::vector<::std::basic_string> GetParameterStrings() { // If an EP doesn't have any CI build pipeline, then there is no need to specify any opset. #ifdef USE_TENSORRT // tensorrt: only enable opset 14 to 17 of onnx tests - provider_names[provider_name_tensorrt] = {opset14, opset15, opset16, opset17}; + provider_names[provider_name_tensorrt] = {opset12, opset14, opset15, opset16, opset17}; #endif #ifdef USE_MIGRAPHX provider_names[provider_name_migraphx] = {opset7, opset8, opset9, opset10, opset11, opset12, opset13, opset14, opset15, opset16, opset17, opset18}; @@ -592,37 +593,10 @@ ::std::vector<::std::basic_string> GetParameterStrings() { ORT_TSTR("mul_uint8"), ORT_TSTR("div_uint8")}; static const ORTCHAR_T* tensorrt_disabled_tests[] = { - ORT_TSTR("udnie"), - ORT_TSTR("rain_princess"), - ORT_TSTR("pointilism"), - ORT_TSTR("mosaic"), - ORT_TSTR("LSTM_Seq_lens_unpacked"), - ORT_TSTR("cgan"), - ORT_TSTR("candy"), - ORT_TSTR("tinyyolov3"), - ORT_TSTR("yolov3"), - ORT_TSTR("mlperf_ssd_resnet34_1200"), - ORT_TSTR("mlperf_ssd_mobilenet_300"), - ORT_TSTR("mask_rcnn"), - ORT_TSTR("faster_rcnn"), - ORT_TSTR("fp16_shufflenet"), - ORT_TSTR("fp16_inception_v1"), - ORT_TSTR("fp16_tiny_yolov2"), - ORT_TSTR("tf_inception_v3"), - ORT_TSTR("tf_mobilenet_v1_1.0_224"), - ORT_TSTR("tf_mobilenet_v2_1.0_224"), - ORT_TSTR("tf_mobilenet_v2_1.4_224"), - ORT_TSTR("tf_resnet_v1_101"), - ORT_TSTR("tf_resnet_v1_152"), - ORT_TSTR("tf_resnet_v1_50"), - ORT_TSTR("tf_resnet_v2_101"), - ORT_TSTR("tf_resnet_v2_152"), - ORT_TSTR("tf_resnet_v2_50"), - ORT_TSTR("convtranspose_1d"), - ORT_TSTR("convtranspose_3d"), - ORT_TSTR("conv_with_strides_and_asymmetric_padding"), - ORT_TSTR("conv_with_strides_padding"), - ORT_TSTR("size") // INVALID_ARGUMENT: Cannot find binding of given name: x + ORT_TSTR("YOLOv3-12"), // needs to run symbolic shape inference shape first + ORT_TSTR("SSD-MobilenetV1-12"), // symbolic shape inference shape error + ORT_TSTR("SSD"), // needs to run symbolic shape inference shape first + ORT_TSTR("size") // INVALID_ARGUMENT: Cannot find binding of given name: x }; std::vector> paths; diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc index 37e0db906d054..48cd5ad99540a 100644 --- a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc @@ -15,6 +15,39 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// Returns a function that creates a graph with a QDQ Gather operator. +template +GetTestQDQModelFn BuildQDQGatherTestCase(const TestInputDef& input_def, + const TestInputDef& indices_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_def, indices_def, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // indices input + NodeArg* indices_input = MakeTestInput(builder, indices_def); + + // Gather op + NodeArg* gather_output = builder.MakeIntermediate(); + Node& gather_node = builder.AddNode("Gather", {input_qdq, indices_input}, {gather_output}); + + for (const auto& attr : attrs) { + gather_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Gather. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, gather_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + // Test the accuracy of a QDQ Gather model on QNN EP. Checks if the QDQ model on QNN EP as accurate as the QDQ model on CPU EP // (compared to float32 model). template @@ -22,7 +55,8 @@ static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestInputDef& indices_def, const std::vector& attrs, int opset, - ExpectedEPNodeAssignment expected_ep_assignment) { + ExpectedEPNodeAssignment expected_ep_assignment, + bool use_contrib_qdq = false) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -31,7 +65,8 @@ static void RunQDQGatherOpTest(const TestInputDef& input_def, #endif auto f32_model_builder = BuildOpTestCase("Gather", {input_def}, {indices_def}, attrs); - auto qdq_model_builder = BuildQDQOpTestCase("Gather", {input_def}, {indices_def}, attrs); + auto qdq_model_builder = BuildQDQGatherTestCase(input_def, indices_def, attrs, + use_contrib_qdq); TestQDQModelAccuracy(f32_model_builder, qdq_model_builder, @@ -52,6 +87,16 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64_Axis0) { ExpectedEPNodeAssignment::All); } +// Test 16-bit QDQ Gather with static int64 indices with default axis. +TEST_F(QnnHTPBackendTests, GatherOp_U16_IndicesStaticInt64_Axis0) { + RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), + TestInputDef({2, 2}, true, {0, 1, 1, 2}), + {utils::MakeAttribute("axis", static_cast(0))}, + 13, + ExpectedEPNodeAssignment::All, + true); // Use 'com.microsoft' Q/DQ ops +} + // Tests that dynamic int64 indices are not supported on HTP backend. TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt64_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 3073dde9d8e4c..3da3dc858175b 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -142,11 +142,6 @@ TEST_F(QnnHTPBackendTests, MatMulOp_HTP_u8) { } // Test QDQ MatMul with 16-bit act, 8-bit weights (static) -// TODO: (SLIGHT) Inaccuracy detected for output 'output', element 0. -// Output quant params: scale=0.0015259021893143654, zero_point=0. -// Expected val: 98 -// QNN QDQ val: 97.720298767089844 (err 0.27970123291015625) -// CPU QDQ val: 97.726402282714844 (err 0.27359771728515625) TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; @@ -158,6 +153,40 @@ TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { 7e-3f); } +// Test QDQ MatMul with uint16 activation uint16 weights, both dynamic +// Inaccuracy detected for output 'output_0', element 1. +// Output quant params: scale=0.0015259021893143654, zero_point=0. +// Expected val: 40 +// QNN QDQ val: 39.681087493896484 (err 0.31891250610351562) +// CPU QDQ val: 39.99847412109375 (err 0.00152587890625) +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16Dynamic) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), + TestInputDef({3, 2}, false, input1_data), + ExpectedEPNodeAssignment::All, + 18, + true, // Use com.microsoft Q/DQ ops + 7e-3f); +} + +// Test QDQ MatMul with uint16 activation uint16 weights, both dynamic +// Inaccuracy detected for output 'output_0', element 1. +// Output quant params: scale=0.71908456087112427, zero_point=1. +// Expected val: 46848.41015625 +// QNN QDQ val: 46844.04296875 (err 4.3671875) +// CPU QDQ val: 46848.359375 (err 0.05078125) +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16DynamicLarge) { + std::vector input0_data = GetFloatDataInRange(-10.0f, 10.0f, 12 * 96 * 512); + std::vector input1_data = GetFloatDataInRange(-10.0f, 10.0f, 12 * 96 * 512); + RunQDQMatMulOpOpTest(TestInputDef({1, 12, 96, 512}, false, input0_data), + TestInputDef({1, 12, 512, 96}, false, input1_data), + ExpectedEPNodeAssignment::All, + 18, + true, // Use com.microsoft Q/DQ ops + 7e-3f); +} + // Test 16-bit QDQ MatMul with static weights // TODO: Inaccuracy detected for output 'output', element 0. // Output quant params: scale=0.0015259021893143654, zero_point=0. diff --git a/onnxruntime/test/providers/qnn/where_htp_test.cc b/onnxruntime/test/providers/qnn/where_htp_test.cc index 49f3ef0fd983a..2d2aa23c28235 100644 --- a/onnxruntime/test/providers/qnn/where_htp_test.cc +++ b/onnxruntime/test/providers/qnn/where_htp_test.cc @@ -126,6 +126,7 @@ TEST_F(QnnHTPBackendTests, WhereLargeDataU8) { // QnnDsp graph prepare failed 13 // QnnDsp Failed to finalize graph QNN_4851394333842096633_1 with err: 1002 // QnnDsp Failed to finalize graph (id: 1) with err 1002 +// Worked with QNN v2.16 TEST_F(QnnHTPBackendTests, DISABLED_WhereLargeDataBroadcastU8) { RunWhereQDQTest(TestInputDef({5120}, false, false, true), TestInputDef({1, 16, 64, 5120}, true, 0.0f, 1.0f), @@ -133,6 +134,17 @@ TEST_F(QnnHTPBackendTests, DISABLED_WhereLargeDataBroadcastU8) { ExpectedEPNodeAssignment::All); } +// .\hexagon\prepare\seq\initial_sequencer_dp.cc:149:ERROR:A single op, +// "q::Broadcast" (Op ID: 19a200000012), requires 0xb40000 bytes of TCM, which is greater than the TCM size of 0x400000! +// .\hexagon\prepare\seq\initial_sequencer_dp.cc : 156 : ERROR : +// The name of the failing op before optimization is : "q::QNN_ElementWiseSelect"(Op ID : 12). +TEST_F(QnnHTPBackendTests, DISABLED_WhereLargeDataBroadcastTransformedU8) { + RunWhereQDQTest(TestInputDef({1, 1, 5120, 1}, false, false, true), + TestInputDef({1, 64, 5120, 16}, true, 0.0f, 1.0f), + TestInputDef({1, 1, 1, 1}, true, {3.0f}), + ExpectedEPNodeAssignment::All); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/python/onnxruntime_test_float8.py b/onnxruntime/test/python/onnxruntime_test_float8.py index 76ca5d9538374..bb63ea234498f 100644 --- a/onnxruntime/test/python/onnxruntime_test_float8.py +++ b/onnxruntime/test/python/onnxruntime_test_float8.py @@ -334,7 +334,7 @@ def test_model_cast_cast_cpu(self, name: str, float_name: str, saturate: int): ] ) @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0") - @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_cast_cast_cuda(self, name: str, float_name: str, saturate: int, provider: str): so = onnxruntime.SessionOptions() so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL @@ -373,7 +373,7 @@ def test_model_cast_cast_cuda(self, name: str, float_name: str, saturate: int, p ] ) @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0") - @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_cast_cast_cuda_ortvalue(self, name: str, float_name: str, saturate: int, provider: str): so = onnxruntime.SessionOptions() so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL @@ -627,7 +627,7 @@ def test_model_cast_like_x2_cpu(self, name: str, float_name: str, saturate: int) ] ) @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0") - @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_qdq_cuda(self, name: str, float_name: str, saturate: int, provider: str): so = onnxruntime.SessionOptions() so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL @@ -693,7 +693,7 @@ def test_model_qdq_cuda_ortvalue(self, name: str, float_name: str, saturate: int self.assertEqual(expect.shape, y.shape) self.assertEqual(expect.dtype, y.dtype) - @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_compare_cpu_cuda_e4m3fn(self): folder = os.path.join(os.path.dirname(__file__), "..", "testdata", "float8") model = os.path.join(folder, "te.cast_fp8_1_fp32.onnx") diff --git a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py index 784ae8ce70bd8..482a334b12b85 100644 --- a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py +++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py @@ -14,10 +14,13 @@ from numpy.testing import assert_allclose from onnx import TensorProto from onnx.checker import check_model +from onnx.defs import onnx_opset_version from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info from onnx.numpy_helper import from_array -from onnxruntime import InferenceSession +from onnxruntime import InferenceSession, get_available_providers + +available_providers = [provider for provider in get_available_providers()] class TestFloat8Gemm8(unittest.TestCase): @@ -89,7 +92,10 @@ def get_model_gemm( ] nodes = [n for n in nodes if n is not None] graph = make_graph(nodes, "gemm", inputs, [d], inits) - onnx_model = make_model(graph, opset_imports=[make_opsetid("", 19)], ir_version=9) + opset_imports = [make_opsetid("", onnx_opset_version() - 1)] + if domain == "com.microsoft": + opset_imports.append(make_opsetid("com.microsoft", 1)) + onnx_model = make_model(graph, opset_imports=opset_imports, ir_version=9) if domain != "com.microsoft": check_model(onnx_model) return onnx_model @@ -192,21 +198,27 @@ def check(f): self.assertEqual(expected.shape, y.shape) self.assertEqual(expected.dtype, y.dtype) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float(self): self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float_default_values(self): self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation=None) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float_relu(self): self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="RELU") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float_gelu(self): self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="GELU") + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float_bias(self): self.common_test_model_gemm("FLOAT", transA=1, beta=1.0, rtol=1e-3) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_model_gemm_float16(self): self.common_test_model_gemm( "FLOAT16", @@ -215,6 +227,8 @@ def test_model_gemm_float16(self): transB=1, ) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") + @unittest.skipIf(not hasattr(TensorProto, "FLOAT8E4M3FN"), reason="needs onnx>=1.14.0") def test_model_gemm_float8_e4m3(self): self.common_test_model_gemm( "FLOAT8E4M3FN", @@ -226,6 +240,7 @@ def test_model_gemm_float8_e4m3(self): ) @parameterized.parameterized.expand(list(itertools.product([0, 1], [0, 1]))) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_combinations_square_matrices(self, transA, transB): self.common_test_model_gemm("FLOAT", transA=transA, transB=transB, rtol=1e-3) @@ -237,6 +252,7 @@ def test_combinations_square_matrices(self, transA, transB): ((2, 3), (2, 5), 1, 0), ] ) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running without CUDA.") def test_combinations(self, shapeA, shapeB, transA, transB): model = make_model( make_graph( @@ -256,7 +272,8 @@ def test_combinations(self, shapeA, shapeB, transA, transB): make_tensor_value_info("B", TensorProto.FLOAT, [None, None]), ], [make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])], - ) + ), + opset_imports=[make_opsetid("", 19), make_opsetid("com.microsoft", 1)], ) sess = InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py deleted file mode 100644 index 4cf2e5d7f7588..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ /dev/null @@ -1,1026 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import copy -import os -import unittest - -import numpy as np -import onnx -import torch -import torch.nn as nn -import torch.nn.functional as F -from helper import get_name -from numpy.testing import assert_allclose -from torchvision import datasets, transforms - -import onnxruntime -from onnxruntime.capi.ort_trainer import ( - IODescription, - LossScaler, - ModelDescription, - ORTTrainer, - generate_sample, - load_checkpoint, - save_checkpoint, -) - -SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) - - -def ort_trainer_learning_rate_description(): - return IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - - -def remove_extra_info(model_desc): - simple_model_desc = copy.deepcopy(model_desc) - for input_desc in simple_model_desc.inputs_: - input_desc.dtype_ = None - input_desc.num_classes_ = None - for output_desc in simple_model_desc.outputs_: - output_desc.dtype_ = None - output_desc.num_classes_ = None - return simple_model_desc - - -def bert_model_description(): - vocab_size = 30528 - input_ids_desc = IODescription( - "input_ids", - ["batch", "max_seq_len_in_batch"], - torch.int64, - num_classes=vocab_size, - ) - segment_ids_desc = IODescription("segment_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2) - input_mask_desc = IODescription("input_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2) - masked_lm_labels_desc = IODescription( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - torch.int64, - num_classes=vocab_size, - ) - next_sentence_labels_desc = IODescription( - "next_sentence_labels", - [ - "batch", - ], - torch.int64, - num_classes=2, - ) - loss_desc = IODescription("loss", [], torch.float32) - - return ModelDescription( - [ - input_ids_desc, - segment_ids_desc, - input_mask_desc, - masked_lm_labels_desc, - next_sentence_labels_desc, - ], - [loss_desc], - ) - - -def map_optimizer_attributes(name): - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) - if no_decay: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - else: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6} - - -def generate_sample_batch(desc, batch_size, device): - desc_ = copy.deepcopy(desc) - desc_.shape_[0] = batch_size - sample = generate_sample(desc_, device) - return sample - - -def create_ort_trainer( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc=True, - loss_scaler=None, - deepspeed_zero_stage=0, -): - model_desc = bert_model_description() - simple_model_desc = remove_extra_info(model_desc) if use_simple_model_desc else model_desc - learning_rate_description = ort_trainer_learning_rate_description() - device = torch.device("cuda", 0) - - onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx")) - - model = ORTTrainer( - onnx_model, - None, - simple_model_desc, - "LambOptimizer", - map_optimizer_attributes, - learning_rate_description, - device, - gradient_accumulation_steps=gradient_accumulation_steps, - world_rank=0, - world_size=1, - loss_scaler=loss_scaler, - use_mixed_precision=use_mixed_precision, - allreduce_post_accumulation=allreduce_post_accumulation, - deepspeed_zero_stage=deepspeed_zero_stage, - ) - - return model, model_desc, device - - -def run_bert_training_test( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc=True, - use_internel_loss_scale=False, -): - torch.manual_seed(1) - onnxruntime.set_seed(1) - - loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None - - model, model_desc, device = create_ort_trainer( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc, - loss_scaler, - ) - - if loss_scaler is None: - loss_scaler = LossScaler(model.loss_scale_input_name, True) - - input_ids_batches = [] - segment_ids_batches = [] - input_mask_batches = [] - masked_lm_labels_batches = [] - next_sentence_labels_batches = [] - batch_size = 16 - num_batches = 8 - for _batch in range(num_batches): - input_ids_batches = [ - *input_ids_batches, - generate_sample_batch(model_desc.inputs_[0], batch_size, device), - ] - segment_ids_batches = [ - *segment_ids_batches, - generate_sample_batch(model_desc.inputs_[1], batch_size, device), - ] - input_mask_batches = [ - *input_mask_batches, - generate_sample_batch(model_desc.inputs_[2], batch_size, device), - ] - masked_lm_labels_batches = [ - *masked_lm_labels_batches, - generate_sample_batch(model_desc.inputs_[3], batch_size, device), - ] - next_sentence_labels_batches = [ - *next_sentence_labels_batches, - generate_sample_batch(model_desc.inputs_[4], batch_size, device), - ] - - lr_batch_list = [ - 0.0000000e00, - 4.6012269e-07, - 9.2024538e-07, - 1.3803681e-06, - 1.8404908e-06, - 2.3006135e-06, - 2.7607362e-06, - 3.2208588e-06, - 3.6809815e-06, - ] - - actual_losses = [] - actual_all_finites = [] - - for batch_count in range(num_batches): - input_ids = generate_sample_batch(model_desc.inputs_[0], batch_size, device) - segment_ids = generate_sample_batch(model_desc.inputs_[1], batch_size, device) - input_mask = generate_sample_batch(model_desc.inputs_[2], batch_size, device) - masked_lm_labels = generate_sample_batch(model_desc.inputs_[3], batch_size, device) - next_sentence_labels = generate_sample_batch(model_desc.inputs_[4], batch_size, device) - lr = lr_batch_list[batch_count] - - learning_rate = torch.tensor([lr]).to(device) - training_args = [ - input_ids, - segment_ids, - input_mask, - masked_lm_labels, - next_sentence_labels, - learning_rate, - ] - if use_mixed_precision: - if not use_internel_loss_scale: - loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device) - training_args.append(loss_scale) - actual_loss = model.train_step(*training_args) - if isinstance(actual_loss, (list, tuple)): - assert len(actual_loss) == 2 - actual_loss, actual_all_finite = actual_loss - if not use_internel_loss_scale: - loss_scaler.update_loss_scale(actual_all_finite.item()) - actual_all_finites = [ - *actual_all_finites, - actual_all_finite.cpu().numpy().item(0), - ] - - actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)] - else: - loss = model(*training_args) - actual_losses = [*actual_losses, loss.cpu().numpy().item(0)] - - if batch_count == num_batches - 1: - # test eval_step api with fetches at the end of the training. - # if eval_step is called during the training, it will affect the actual training loss (training session is stateful). - eval_loss = model.eval_step( - input_ids, - segment_ids, - input_mask, - masked_lm_labels, - next_sentence_labels, - fetches=["loss"], - ) - eval_loss = eval_loss.cpu().numpy().item(0) - - # If using internal loss scale, all_finites are handled internally too. - if use_mixed_precision and not use_internel_loss_scale: - return actual_losses, actual_all_finites, eval_loss - else: - return actual_losses, eval_loss - - -class MNISTWrapper: - class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - self.register_buffer("bias_buffer", torch.tensor(1e-6)) - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - out = torch.add(out, self.bias_buffer.to(out.dtype)) - return out - - class NeuralNetWithLoss(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x, target): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return F.nll_loss(F.log_softmax(out, dim=1), target), out - - def my_loss(x, target): # noqa: N805 - return F.nll_loss(F.log_softmax(x, dim=1), target) - - def train_with_trainer(self, learningRate, trainer, device, train_loader, epoch): - actual_losses = [] - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - args_log_interval = 100 - if batch_idx % args_log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - actual_losses = [*actual_losses, loss.cpu().numpy().item()] - - return actual_losses - - # TODO: comple this once ORT training can do evaluation. - def test_with_trainer(self, trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = F.log_softmax(trainer.eval_step((data), fetches=["probability"]), dim=1) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, - correct, - len(test_loader.dataset), - 100.0 * correct / len(test_loader.dataset), - ) - ) - - return test_loss, correct / len(test_loader.dataset) - - def mnist_model_description(): - input_desc = IODescription("input1", ["batch", 784], torch.float32) - label_desc = IODescription( - "label", - [ - "batch", - ], - torch.int64, - num_classes=10, - ) - loss_desc = IODescription("loss", [], torch.float32) - probability_desc = IODescription("probability", ["batch", 10], torch.float32) - return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc]) - - def get_loaders(self): - args_batch_size = 64 - args_test_batch_size = 1000 - - kwargs = {"num_workers": 0, "pin_memory": True} - # set shuffle to False to get deterministic data set among different torch version - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - os.path.join(SCRIPT_DIR, "data"), - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args_batch_size, - shuffle=False, - **kwargs, - ) - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - os.path.join(SCRIPT_DIR, "data"), - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args_test_batch_size, - shuffle=False, - **kwargs, - ) - - return train_loader, test_loader - - def get_model(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - - # warning: changes the pytorch random generator state - model = MNISTWrapper.NeuralNet(input_size, hidden_size, num_classes) - model_desc = MNISTWrapper.mnist_model_description() - return model, model_desc - - def get_model_with_internal_loss(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - - # warning: changes the pytorch random generator state - model = MNISTWrapper.NeuralNetWithLoss(input_size, hidden_size, num_classes) - model_desc = MNISTWrapper.mnist_model_description() - return model, model_desc - - def get_trainer( - self, - model, - model_desc, - device, - onnx_opset_ver=12, - frozen_weights=[], # noqa: B006 - internal_loss_fn=False, - get_lr_this_step=None, - optimizer="SGDOptimizer", - ): - loss_fn = MNISTWrapper.my_loss if not internal_loss_fn else None - return ORTTrainer( - model, - loss_fn, - model_desc, - optimizer, - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - _opset_version=onnx_opset_ver, - frozen_weights=frozen_weights, - get_lr_this_step=get_lr_this_step, - ) - - -class TestOrtTrainer(unittest.TestCase): - def run_mnist_training_and_testing(onnx_opset_ver): # noqa: N805 - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - trainer = mnist.get_trainer(model, model_desc, device, onnx_opset_ver=onnx_opset_ver) - - learningRate = 0.01 # noqa: N806 - args_epochs = 2 - expected_losses = [ - 2.312044143676758, - 0.8018650412559509, - 0.5819257497787476, - 0.47025489807128906, - 0.35800155997276306, - 0.41124576330184937, - 0.2731882333755493, - 0.4201386570930481, - 0.39458805322647095, - 0.38380366563796997, - 0.2722422480583191, - 0.24230478703975677, - 0.23505745828151703, - 0.33442264795303345, - 0.21140924096107483, - 0.31545233726501465, - 0.18556523323059082, - 0.3453553020954132, - 0.29598352313041687, - 0.3595045208930969, - ] - - expected_test_losses = [0.3145490005493164, 0.256188737487793] - expected_test_accuracies = [0.9075, 0.9265] - - actual_losses = [] - actual_test_losses, actual_accuracies = [], [] - for epoch in range(1, args_epochs + 1): - actual_losses = [ - *actual_losses, - *mnist.train_with_trainer(learningRate, trainer, device, train_loader, epoch), - ] - - test_loss, accuracy = mnist.test_with_trainer(trainer, device, test_loader) - actual_test_losses = [*actual_test_losses, test_loss] - actual_accuracies = [*actual_accuracies, accuracy] - - # if you update outcomes, also do so for resume from checkpoint test - # args_checkpoint_epoch = 1 - # if epoch == args_checkpoint_epoch: - # state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()} - # torch.save(state, get_name("ckpt_mnist.pt")) - - print("actual_losses=", actual_losses) - print("actual_test_losses=", actual_test_losses) - print("actual_accuracies=", actual_accuracies) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # import pdb; pdb.set_trace() - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_test_losses, - actual_test_losses, - rtol=rtol, - err_msg="test loss mismatch", - ) - assert_allclose( - expected_test_accuracies, - actual_accuracies, - rtol=rtol, - err_msg="test accuracy mismatch", - ) - - def test_mnist_training_and_testing_opset12(self): - TestOrtTrainer.run_mnist_training_and_testing(onnx_opset_ver=12) - - def test_mnist_resume_training_and_testing(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - learningRate = 0.01 # noqa: N806 - args_epochs = 2 - args_checkpoint_epoch = 1 - # should match those in test without checkpointing - expected_losses = [ - 0.26509523391723633, - 0.24135658144950867, - 0.2397943139076233, - 0.3351520597934723, - 0.20998981595039368, - 0.31488314270973206, - 0.18481917679309845, - 0.34727591276168823, - 0.2971782684326172, - 0.3609251379966736, - ] - - expected_test_losses = [0.25632242965698243] - expected_test_accuracies = [0.9264] - - actual_losses = [] - actual_test_losses, actual_accuracies = [], [] - - # restore from checkpoint - resume_trainer = mnist.get_trainer(model, model_desc, device) - checkpoint = torch.load(get_name("ckpt_mnist.pt"), map_location="cpu") - torch.set_rng_state(checkpoint["rng_state"]) - resume_trainer.load_state_dict(checkpoint["model"], strict=True) - - # continue .. - for epoch in range(args_checkpoint_epoch + 1, args_epochs + 1): - actual_losses = [ - *actual_losses, - *mnist.train_with_trainer(learningRate, resume_trainer, device, train_loader, epoch), - ] - - test_loss, accuracy = mnist.test_with_trainer(resume_trainer, device, test_loader) - actual_test_losses = [*actual_test_losses, test_loss] - actual_accuracies = [*actual_accuracies, accuracy] - - print("actual_losses=", actual_losses) - print("actual_test_losses=", actual_test_losses) - print("actual_accuracies=", actual_accuracies) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # import pdb; pdb.set_trace() - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_test_losses, - actual_test_losses, - rtol=rtol, - err_msg="test loss mismatch", - ) - assert_allclose( - expected_test_accuracies, - actual_accuracies, - rtol=rtol, - err_msg="test accuracy mismatch", - ) - - def test_mnist_state_dict(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - state_dict = trainer.state_dict() - assert state_dict == {} - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - state_dict = trainer.state_dict() - assert state_dict.keys() == { - "fc1.bias", - "fc1.weight", - "fc2.bias", - "fc2.weight", - "bias_buffer", - } - - def test_mnist_save_as_onnx(self): - torch.manual_seed(1) - device = torch.device("cuda") - onnx_file_name = "mnist.onnx" - if os.path.exists(onnx_file_name): - os.remove(onnx_file_name) - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - trainer.save_as_onnx(onnx_file_name) - assert not os.path.exists(onnx_file_name) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - trainer.save_as_onnx(onnx_file_name) - assert os.path.exists(onnx_file_name) - - def test_mnist_device(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - for model_device in [torch.device("cpu"), torch.device("cuda")]: - model.to(model_device) - trainer = mnist.get_trainer(model, model_desc, device) - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - def test_mnist_initializer_names(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - assert ({n.name for n in trainer.onnx_model_.graph.initializer} - {"bias_buffer"}) == { - n for n, t in model.named_parameters() - } - - def test_mnist_initializer_names_with_internal_loss(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model_with_internal_loss() - - def get_lr_this_step(global_step): - learningRate = 0.02 # noqa: N806 - return torch.tensor([learningRate]) - - trainer = mnist.get_trainer( - model, - model_desc, - device, - internal_loss_fn=True, - get_lr_this_step=get_lr_this_step, - ) - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target) - - assert {n.name for n in trainer.onnx_model_.graph.initializer} == {n for n, t in model.named_parameters()} - - def test_mnist_frozen_weight(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=["fc1.weight"]) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_1 = trainer.state_dict()["fc1.weight"] - fc2_trainstep_1 = trainer.state_dict()["fc2.weight"] - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_2 = trainer.state_dict()["fc1.weight"] - fc2_trainstep_2 = trainer.state_dict()["fc2.weight"] - assert np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and not np.array_equal(fc2_trainstep_1, fc2_trainstep_2) - - def test_mnist_torch_buffer(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_1 = trainer.state_dict()["fc1.weight"] - bias_buffer_trainstep_1 = trainer.state_dict()["bias_buffer"] - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_2 = trainer.state_dict()["fc1.weight"] - bias_buffer_trainstep_2 = trainer.state_dict()["bias_buffer"] - assert not np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and np.array_equal( - bias_buffer_trainstep_1, bias_buffer_trainstep_2 - ) - - def test_mnist_frozen_weight_checkpoint(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=["fc1.weight"]) - - learningRate = 0.02 # noqa: N806 - - # do one train step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - # do one eval step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.eval_step(data, target) - - # save checkpoint, load model and compare - state_dict = trainer.state_dict() - - new_model, _ = mnist.get_model() - trainer = mnist.get_trainer(new_model, model_desc, device, frozen_weights=["fc1.weight"]) - trainer.load_state_dict(state_dict) - - ckpt_loss, _ = trainer.eval_step(data, target) - assert loss == ckpt_loss - - loaded_state_dict = trainer.state_dict() - assert state_dict.keys() == loaded_state_dict.keys() - - def test_mnist_training_checkpoint(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer( - model, - model_desc, - device, - optimizer="LambOptimizer", - frozen_weights=["fc1.weight"], - ) - - learningRate = 0.02 # noqa: N806 - - # do 5 train step - for _i in range(5): - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - # do one eval step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.eval_step(data, target) - - # save checkpoint, load model and compare - state_dict = trainer.state_dict() - - new_model, _ = mnist.get_model() - trainer = mnist.get_trainer( - new_model, - model_desc, - device, - optimizer="LambOptimizer", - frozen_weights=["fc1.weight"], - ) - trainer.load_state_dict(state_dict) - - ckpt_loss, _ = trainer.eval_step(data, target) - assert loss == ckpt_loss - - loaded_state_dict = trainer.state_dict() - assert state_dict.keys() == loaded_state_dict.keys() - for key in state_dict: - assert np.array_equal(state_dict[key], loaded_state_dict[key]) - - def test_bert_training_basic(self): - expected_losses = [ - 11.027887, - 11.108191, - 11.055356, - 11.040912, - 10.960277, - 11.02691, - 11.082471, - 10.920979, - ] - expected_eval_loss = [10.958977] - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=False, - ) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # print('losses expected: ', expected_losses) - # print('losses actual: ', actual_losses) - # print('eval_loss expected: ', expected_eval_loss) - # print('eval_loss actual: ', actual_eval_loss) - # import pdb; pdb.set_trace() - - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_gradient_accumulation(self): - expected_losses = [ - 11.027887, - 11.108191, - 11.055354, - 11.040904, - 10.960266, - 11.026897, - 11.082475, - 10.920998, - ] - expected_eval_loss = [10.958998] - - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=4, - use_mixed_precision=False, - allreduce_post_accumulation=False, - ) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # print('losses expected: ', expected_losses) - # print('losses actual: ', actual_losses) - # print('eval_loss expected: ', expected_eval_loss) - # print('eval_loss actual: ', actual_eval_loss) - # import pdb; pdb.set_trace() - - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_checkpointing_basic(self): - model, _, _ = create_ort_trainer( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=True, - use_simple_model_desc=True, - loss_scaler=None, - ) - sd = model.state_dict() - - # modify one of the default values - sd["bert.encoder.layer.0.attention.output.LayerNorm.weight"] += 1 - model.load_state_dict(sd) - - ckpt_dir = "testdata" - save_checkpoint(model, ckpt_dir, "bert_toy_save_test") - del model - - # create new model - model2, _, _ = create_ort_trainer( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=True, - use_simple_model_desc=True, - loss_scaler=None, - ) - - # load changed checkpoint - load_checkpoint(model2, ckpt_dir, "bert_toy_save_test") - loaded_sd = model2.state_dict() - - for k, v in loaded_sd.items(): - assert torch.all(torch.eq(v, sd[k])) - - def test_wrap_model_loss_fn_state_dict(self): - torch.manual_seed(1) - device = torch.device("cuda") - - class LinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - - def forward(self, y=None, x=None): - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - pt_model = LinearModel() - data = torch.randn(2, 2) - label = torch.tensor([0, 1], dtype=torch.int64) - input_desc = IODescription("x", [2, 2], torch.float32) - label_desc = IODescription( - "label", - [ - 2, - ], - torch.int64, - num_classes=4, - ) - output_desc = IODescription("output", [2, 4], torch.float32) - loss_desc = IODescription("loss", [], torch.float32) - model_desc = ModelDescription([input_desc, label_desc], [loss_desc, output_desc]) - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - def get_lr_this_step(global_step): - learningRate = 0.02 # noqa: N806 - return torch.tensor([learningRate]) - - ort_trainer = ORTTrainer( - pt_model, - loss_fn, - model_desc, - "SGDOptimizer", - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - get_lr_this_step=get_lr_this_step, - ) - ort_trainer.train_step(x=data, label=label) - state_dict = ort_trainer.state_dict() - assert state_dict.keys() == {"linear.bias", "linear.weight"} - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py deleted file mode 100644 index 3b994e6f26710..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import unittest - -from numpy.testing import assert_allclose, assert_array_equal -from onnxruntime_test_ort_trainer import run_bert_training_test - - -class TestOrtTrainer(unittest.TestCase): - def test_bert_training_mixed_precision(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006105422973633, - 11.047048568725586, - 11.027417182922363, - 11.015759468078613, - 11.060905456542969, - 10.971782684326172, - ] - expected_all_finites = [True, True, True, True, True, True, True, True] - expected_eval_loss = [10.959012985229492] - actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_mixed_precision_internal_loss_scale(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006105422973633, - 11.047048568725586, - 11.027417182922363, - 11.015759468078613, - 11.060905456542969, - 10.971782684326172, - ] - expected_eval_loss = [10.959012985229492] - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - use_internel_loss_scale=True, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_gradient_accumulation_mixed_precision(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006077766418457, - 11.047025680541992, - 11.027434349060059, - 11.0156831741333, - 11.060973167419434, - 10.971841812133789, - ] - expected_all_finites = [True, True] - expected_eval_loss = [10.95903205871582] - actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=4, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py b/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py deleted file mode 100644 index 540f39b797bdb..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import unittest - -import torch -import torch.nn as nn -from numpy.testing import assert_allclose -from onnxruntime_test_ort_trainer import map_optimizer_attributes, ort_trainer_learning_rate_description -from onnxruntime_test_training_unittest_utils import process_dropout - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer - - -class TestTrainingDropout(unittest.TestCase): - def setUp(self): - torch.manual_seed(1) - onnxruntime.set_seed(1) - - @unittest.skip( - "Temporarily disable this test. The graph below will trigger ORT to " - "sort backward graph before forward graph which gives incorrect result. " - "https://github.com/microsoft/onnxruntime/issues/16801" - ) - def test_training_and_eval_dropout(self): - class TwoDropoutNet(nn.Module): - def __init__(self, drop_prb_1, drop_prb_2, dim_size): - super().__init__() - self.drop_1 = nn.Dropout(drop_prb_1) - self.drop_2 = nn.Dropout(drop_prb_2) - self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32)) - - def forward(self, x): - x = x + self.weight_1 - x = self.drop_1(x) - x = self.drop_2(x) - output = x - return output[0] - - dim_size = 3 - device = torch.device("cuda", 0) - # This will drop all values, therefore expecting all 0 in output tensor - model = TwoDropoutNet(0.999, 0.999, dim_size) - input_desc = IODescription("input", [dim_size], torch.float32) - output_desc = IODescription("output", [], torch.float32) - model_desc = ModelDescription([input_desc], [output_desc]) - lr_desc = ort_trainer_learning_rate_description() - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes, - lr_desc, - device, - postprocess_model=process_dropout, - world_rank=0, - world_size=1, - ) - input = torch.ones(dim_size, dtype=torch.float32).to(device) - expected_training_output = [0.0] - expected_eval_output = [1.0] - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [input, learning_rate] - train_output = model.train_step(*input_args) - - rtol = 1e-04 - assert_allclose( - expected_training_output, - train_output.item(), - rtol=rtol, - err_msg="dropout training loss mismatch", - ) - - eval_output = model.eval_step(input) - assert_allclose( - expected_eval_output, - eval_output.item(), - rtol=rtol, - err_msg="dropout eval loss mismatch", - ) - - # Do another train step to make sure it's using original ratios - train_output_2 = model.train_step(*input_args) - assert_allclose( - expected_training_output, - train_output_2.item(), - rtol=rtol, - err_msg="dropout training loss 2 mismatch", - ) - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py b/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py deleted file mode 100644 index 3d3feca06a99b..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np -from onnx import numpy_helper - - -def get_node_index(model, node): - i = 0 - while i < len(model.graph.node): - if model.graph.node[i] == node: - break - i += 1 - return i if i < len(model.graph.node) else None - - -def add_const(model, name, output, t_value=None, f_value=None): - const_node = model.graph.node.add() - const_node.op_type = "Constant" - const_node.name = name - const_node.output.extend([output]) - attr = const_node.attribute.add() - attr.name = "value" - if t_value is not None: - attr.type = 4 - attr.t.CopyFrom(t_value) - else: - attr.type = 1 - attr.f = f_value - return const_node - - -def process_dropout(model): - dropouts = [] - index = 0 - for node in model.graph.node: - if node.op_type == "Dropout": - new_dropout = model.graph.node.add() - new_dropout.op_type = "TrainableDropout" - new_dropout.name = "TrainableDropout_%d" % index - # make ratio node - ratio = np.asarray([node.attribute[0].f], dtype=np.float32) - print(ratio.shape) - ratio_value = numpy_helper.from_array(ratio) - ratio_node = add_const( - model, - "dropout_node_ratio_%d" % index, - "dropout_node_ratio_%d" % index, - t_value=ratio_value, - ) - print(ratio_node) - new_dropout.input.extend([node.input[0], ratio_node.output[0]]) - new_dropout.output.extend(node.output) - dropouts.append(get_node_index(model, node)) - index += 1 - dropouts.sort(reverse=True) - for d in dropouts: - del model.graph.node[d] - model.opset_import[0].version = 10 diff --git a/onnxruntime/test/python/transformers/conformer_model_generator.py b/onnxruntime/test/python/transformers/conformer_model_generator.py new file mode 100644 index 0000000000000..71e4f2b63cf4f --- /dev/null +++ b/onnxruntime/test/python/transformers/conformer_model_generator.py @@ -0,0 +1,543 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from typing import List + +import numpy as np +import onnx +from bert_model_generator import float_tensor +from onnx import TensorProto, helper, numpy_helper + + +# Adapted from bert_model_generator.py +def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = ( + [np.random.uniform(low, high) for _ in range(total_elements)] + if random + else [0.0] * total_elements + if zeros + else [1.0] * total_elements + ) + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights), weights + + +def create_conformer_attention( + hidden_size=512, + num_heads=8, + epsilon=0.000009999999747378752, + add_before_layernorm=False, + fused=False, +): + # Get head size and ensure head size is an integer + assert hidden_size % num_heads == 0 + head_size = hidden_size // num_heads + + # Construct input and output nodes + inputs = [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("inp_cache_k", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]), + helper.make_tensor_value_info("inp_cache_v", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]), + ] + outputs = [ + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 8, hidden_size]), + helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("oup_cache_k", TensorProto.FLOAT, ["batch_size", 8, 80, 64]), + helper.make_tensor_value_info("oup_cache_v", TensorProto.FLOAT, ["batch_size", 8, 80, 64]), + ] + nodes = [] + + # Create layernorm (Add + LayerNorm or SkipLayerNorm) + if add_before_layernorm: + nodes.extend( + [ + helper.make_node( + "Add", ["input_0", "input_1"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm" + ), + helper.make_node( + "LayerNormalization", + ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"], + ["layernorm_add_output_to_matmul"], + "layernorm", + epsilon=epsilon, + ), + ] + ) + else: + nodes.append( + helper.make_node( + "SkipLayerNormalization", + ["input_0", "input_1", "layernorm_weight", "layernorm_bias"], + ["layernorm_add_output_to_matmul", "", "", "layernorm_add_output_to_skiplayernorm"], + "skiplayernorm", + domain="com.microsoft", + epsilon=epsilon, + ) + ) + + if fused: + fused_q_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "q_weight"], + ["q_matmul_output"], + "q_path_matmul", + ), + helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), + helper.make_node( + "Reshape", ["q_add_output", "k_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d", allowzero=0 + ), + helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Div", + ["q_4d_bnsh", "q_scale"], + ["q_div_output"], + "q_div_by_sqrt_head_size", + ), + ] + nodes.extend(fused_q_nodes) + nodes.extend( + [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "k_weight"], + ["k_matmul_output"], + "k_path_matmul", + ), + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "v_weight"], + ["v_matmul_output"], + "v_path_matmul", + ), + helper.make_node( + "Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb", allowzero=0 + ), + helper.make_node( + "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "MatMul", + ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"], + ["pos_matmul"], + "pos_embed_matmul", + ), + helper.make_node( + "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "Reshape", + ["transpose_pos_matmul", "position_embed_output"], + ["reshape_position_emb"], + "final_reshape_pos_emb", + allowzero=0, + ), + helper.make_node( + "MultiHeadAttention", + [ + "q_matmul_output", + "k_matmul_output", + "v_matmul_output", + "Attention_0_qkv_bias", + "", + "reshape_position_emb", + "gather_past_k_output", + "gather_past_v_output", + ], + ["attn_output", "oup_cache_k", "oup_cache_v"], + "Attention_0", + domain="com.microsoft", + num_heads=num_heads, + ), + ] + ) + # Create nodes used with qkv concats, reshapes, and transposes + nodes.extend( + [ + helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0), + helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), + helper.make_node( + "Mul", + ["gather_0_output", "num_heads_int"], + ["mul_attn_heads_output"], + "mul_num_heads", + ), + helper.make_node( + "Unsqueeze", + ["mul_attn_heads_output", "unsqueeze_axes_input"], + ["unsqueeze_position_embed"], + "unsqueeze_position_embed", + ), + helper.make_node( + "Concat", + ["unsqueeze_position_embed", "neg_one", "head_size"], + ["position_embed_output"], + "position_embed_concat_output", + axis=0, + ), + helper.make_node( + "Unsqueeze", + ["gather_0_output", "unsqueeze_axes_input"], + ["unsqueeze_attn_heads_output"], + "unsqueeze_num_heads", + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["k_attn_heads_output"], + "k_num_heads", + axis=0, + ), + ] + ) + + nodes.extend( + [ + helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0), + helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0), + ] + ) + else: + # Create nodes for Q/K/V paths + q_nodes = [ + helper.make_node( + "MatMul", ["layernorm_add_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" + ), + helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), + helper.make_node("Reshape", ["q_add_output", "q_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d"), + helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Div", + ["q_4d_bnsh", "q_scale"], + ["q_div_output"], + "q_div_by_sqrt_head_size", + ), + ] + k_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "k_weight"], + ["k_matmul_output"], + "k_path_matmul", + ), + helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), + helper.make_node("Reshape", ["k_add_output", "k_attn_heads_output"], ["k_4d_bsnh"], "k_reshape_to_4d"), + helper.make_node("Transpose", ["k_4d_bsnh"], ["k_4d_bnsh"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Concat", + ["gather_past_k_output", "k_4d_bnsh"], + ["oup_cache_k"], + "concat_past_k_and_curr_k", + axis=2, + ), + helper.make_node( + "Transpose", + ["oup_cache_k"], + ["k_output_transpose"], + "k_transpose_last_two_dims", + perm=[0, 1, 3, 2], + ), + ] + v_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "v_weight"], + ["v_matmul_output"], + "v_path_matmul", + ), + helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), + helper.make_node("Reshape", ["v_add_output", "v_attn_heads_output"], ["v_4d_bsnh"], "v_reshape_to_4d"), + helper.make_node("Transpose", ["v_4d_bsnh"], ["v_4d_bnsh"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Concat", + ["gather_past_v_output", "v_4d_bnsh"], + ["oup_cache_v"], + "concat_past_v_and_curr_v", + axis=2, + ), + ] + pos_embed = [ + helper.make_node("Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb"), + helper.make_node( + "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "MatMul", + ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"], + ["pos_matmul"], + "pos_embed_matmul", + ), + helper.make_node( + "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "Reshape", + ["transpose_pos_matmul", "position_embed_output"], + ["reshape_position_emb"], + "final_reshape_pos_emb", + ), + ] + nodes.extend(q_nodes) + nodes.extend(k_nodes) + nodes.extend(v_nodes) + nodes.extend(pos_embed) + + # Create nodes used with qkv concats, reshapes, and transposes + nodes.extend( + [ + helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0), + helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), + helper.make_node( + "Mul", + ["gather_0_output", "num_heads_int"], + ["mul_attn_heads_output"], + "mul_num_heads", + ), + helper.make_node( + "Unsqueeze", + ["mul_attn_heads_output", "unsqueeze_axes_input"], + ["unsqueeze_position_embed"], + "unsqueeze_position_embed", + ), + helper.make_node( + "Concat", + ["unsqueeze_position_embed", "neg_one", "head_size"], + ["position_embed_output"], + "position_embed_concat_output", + axis=0, + ), + helper.make_node( + "Unsqueeze", + ["gather_0_output", "unsqueeze_axes_input"], + ["unsqueeze_attn_heads_output"], + "unsqueeze_num_heads", + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["q_attn_heads_output"], + "q_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["k_attn_heads_output"], + "k_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["v_attn_heads_output"], + "v_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size"], + ["bsd_format"], + axis=0, + ), + helper.make_node( + "Constant", + inputs=[], + outputs=["q_bsnh_reshape"], + value=numpy_helper.from_array( + np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" + ), + ), + ] + ) + + nodes.extend( + [ + helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0), + helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0), + ] + ) + + # Compute Q x K' + nodes.extend( + [ + helper.make_node( + "MatMul", + [ + "q_div_output", + "k_output_transpose", + ], + ["qk_output"], + "matmul_qk", + ) + ] + ) + + # Create nodes for computing softmax(Q x K') x V + nodes.extend( + [ + helper.make_node( + "Add", + [ + "qk_output", + "reshape_position_emb", + ], + ["add_qk_output"], + "add_qk", + ), + helper.make_node( + "Softmax", + ["add_qk_output"], + ["softmax_output"], + "softmax_qk", + axis=2, + ), + helper.make_node( + "MatMul", + ["softmax_output", "oup_cache_v"], + ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], + "matmul_qkv", + ), + helper.make_node( + "Transpose", + ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], + ["qkv_bsnh"], + "transpose_bnsh_to_bsnh", + perm=[0, 2, 1, 3], + ), + helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), + ] + ) + + # Create final nodes to conclude attention + nodes.append( + helper.make_node( + "MatMul", + ["attn_output", "matmul_after_attn_initializer"], + ["matmul_after_attn_output"], + "matmul_after_attn", + ), + ) + if not fused: + next_sln_inputs = [ + "layernorm_add_output_to_skiplayernorm", + "add_after_attn_output", + "layernorm_weight", + "layernorm_bias", + ] + nodes.extend( + [ + helper.make_node( + "Add", + ["add_after_attn_initializer", "matmul_after_attn_output"], + ["add_after_attn_output"], + "add_after_attn", + ), + helper.make_node( + "SkipLayerNormalization", + next_sln_inputs, + ["output_0", "", "", "output_1"], + "next_skiplayernorm", + domain="com.microsoft", + epsilon=epsilon, + ), + ] + ) + else: + next_sln_inputs = [ + "matmul_after_attn_output", + "layernorm_add_output_to_skiplayernorm", + "layernorm_weight", + "layernorm_bias", + "add_after_attn_initializer", + ] + nodes.append( + helper.make_node( + "SkipLayerNormalization", + next_sln_inputs, + ["output_0", "", "", "output_1"], + "SkipLayerNorm_AddBias_0", + domain="com.microsoft", + epsilon=epsilon, + ) + ) + + # Create initializers + v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) + v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) + q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) + q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) + k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) + k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size]) + + qkv_bias = helper.make_tensor( + "Attention_0_qkv_bias", + TensorProto.FLOAT, + [3 * hidden_size], + q_bias_data + k_bias_data + v_bias_data, + ) + initializers = [ + float_tensor("layernorm_weight", [hidden_size]), + float_tensor("layernorm_bias", [hidden_size]), + float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), + float_tensor("add_after_attn_initializer", [hidden_size]), + ] + + # Add Q/K/V weight tensors as initializers + if fused: + initializers.extend([q_weight, k_weight, v_weight]) + initializers.extend([q_bias]) + initializers.append(qkv_bias) + initializers.extend( + [ + numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), + numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), + numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), + numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), + numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), + numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), + numpy_helper.from_array(np.array([0, 0, num_heads, head_size], dtype="int64"), name="q_bsnh_reshape"), + ] + ) + else: + initializers.extend([q_weight, k_weight, v_weight]) + + initializers.extend([q_bias, k_bias, v_bias]) + + initializers.extend( + [ + numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), + numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), + numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), + numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), + numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), + numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), + numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), + numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), + numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), + ] + ) + + # Construct graph + graph = helper.make_graph(nodes, "conformer_self_mha_graph", inputs, outputs, initializers, doc_string="conformer") + opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) + return helper.make_model(graph, opset_imports=(opsetid,)) + + +if __name__ == "__main__": + np.random.seed(2) + num_heads = 8 + hidden_size = 512 + + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size) + onnx.save(model, "conformer_self_mha.onnx") + + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True) + onnx.save(model, "./test_data/models/conformer/conformer_self_mha_fused.onnx") diff --git a/onnxruntime/test/python/transformers/test_conformer.py b/onnxruntime/test/python/transformers/test_conformer.py new file mode 100644 index 0000000000000..471ba9756bcf8 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_conformer.py @@ -0,0 +1,69 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import unittest + +import onnx +from conformer_model_generator import create_conformer_attention +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +class TestFusion(unittest.TestCase): + def verify_fusion(self, optimized_model, expected_model_filename): + optimized_model.topological_sort(is_deterministic=True) + + expected_model_path = os.path.join( + os.path.dirname(__file__), "test_data", "models", "conformer", expected_model_filename + ) + print("Expected model path = ", expected_model_path) + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + nodes = optimized_model.model.graph.node + self.assertEqual(len(nodes), len(expected_model.model.graph.node)) + + for i in range(len(nodes)): + self.assertEqual(nodes[i], expected_model.model.graph.node[i]) + + for expected_initializer in expected_model.model.graph.initializer: + print("Expected initializer initial = ", expected_initializer.name) + self.assertTrue( + OnnxModel.has_same_value( + optimized_model.get_initializer(expected_initializer.name), expected_initializer + ) + ) + + def test_ct_mha_fusion(self): + num_heads = 8 + hidden_size = 512 + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False) + dir = "." + model_path = os.path.join(dir, "conformer_self_mha.onnx") + onnx.save(model, model_path) + options = FusionOptions("conformer") + optimized_model = optimize_model( + model_path, + model_type="conformer", + num_heads=num_heads, + hidden_size=hidden_size, + optimization_options=options, + ) + os.remove(model_path) + self.verify_fusion(optimized_model, "conformer_self_mha_fused.onnx") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx new file mode 100644 index 0000000000000..9d882751db265 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 99f62ffdb9f53..8a839875de2a2 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -183,7 +183,9 @@ def create_multihead_attention_graph(config): return model.SerializeToString() -def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSNH, share_buffer=True): +def create_group_query_attention_graph_prompt( + config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 +): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length nodes = [ @@ -202,6 +204,7 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -297,6 +300,26 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN config.head_size, ], ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), ] graph = helper.make_graph( @@ -310,7 +333,9 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN return model.SerializeToString() -def create_group_query_attention_graph_past(config, past_kv_format=Formats.BSNH, share_buffer=True): +def create_group_query_attention_graph_past( + config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 +): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length @@ -331,6 +356,7 @@ def create_group_query_attention_graph_past(config, past_kv_format=Formats.BSNH, "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -636,8 +662,12 @@ def mha_func(q, k, v, config): return output -def gqa_prompt_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): - onnx_model_str = create_group_query_attention_graph_prompt(config, past_kv_format, share_buffer) +def gqa_prompt_func( + q, k, v, config, new_k, new_v, seqlens_k=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True +): + onnx_model_str = create_group_query_attention_graph_prompt( + config, past_kv_format, share_buffer, local_window_size=window_size + ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None @@ -706,8 +736,12 @@ def gqa_prompt_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_forma return output, present_k, present_v -def gqa_past_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): - onnx_model_str = create_group_query_attention_graph_past(config, past_kv_format, share_buffer) +def gqa_past_func( + q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1 +): + onnx_model_str = create_group_query_attention_graph_past( + config, past_kv_format, share_buffer, local_window_size=window_size + ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() @@ -796,6 +830,28 @@ def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_paddi return col_idx > row_idx + sk - sq +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + def attention_ref( q, k, @@ -805,6 +861,7 @@ def attention_ref( dropout_p=0.0, dropout_mask=None, causal=False, + window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, ): @@ -817,6 +874,8 @@ def attention_ref( key_padding_mask: (batch_size, seqlen_k) dropout_p: float dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16. reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) @@ -826,6 +885,8 @@ def attention_ref( output: (batch_size, seqlen_q, nheads, head_dim) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ + if causal: + window_size = (window_size[0], 0) dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() @@ -839,12 +900,24 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - if causal: - causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device) - scores.masked_fill_(causal_mask, float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) - if causal: # Some rows are completely masked out so we fill them with zero instead of NaN - attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) @@ -853,7 +926,6 @@ def attention_ref( output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) @@ -957,6 +1029,8 @@ def parity_check_mha( def parity_check_gqa_prompt( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1007,6 +1081,15 @@ def parity_check_gqa_prompt( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = k.clone() v_cache_ref = v.clone() @@ -1033,14 +1116,18 @@ def parity_check_gqa_prompt( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) + out, present_k, present_v = gqa_prompt_func( + q, k, v, config, new_k, new_v, cache_seqlens, left_window_size, past_format, True + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1052,6 +1139,10 @@ def parity_check_gqa_prompt( # Compare results print( "KV-buffer", + " causal:", + causal, + " local:", + local, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1080,6 +1171,8 @@ def parity_check_gqa_prompt( def parity_check_gqa_prompt_no_buff( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1112,6 +1205,15 @@ def parity_check_gqa_prompt_no_buff( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = new_k.clone() v_cache_ref = new_v.clone() @@ -1132,14 +1234,18 @@ def parity_check_gqa_prompt_no_buff( new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func(q, None, None, config, new_k, new_v, cache_seqlens, past_format, False) + out, present_k, present_v = gqa_prompt_func( + q, None, None, config, new_k, new_v, cache_seqlens, left_window_size, past_format, False + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1179,6 +1285,8 @@ def parity_check_gqa_prompt_no_buff( def parity_check_gqa_past( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1228,6 +1336,14 @@ def parity_check_gqa_past( dtype=torch.float16, requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) # Pytorch to compare k_cache_ref = k.clone() @@ -1253,14 +1369,18 @@ def parity_check_gqa_past( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) + out, present_k, present_v = gqa_past_func( + q, k, v, config, new_k, new_v, cache_seqlens, past_format, True, left_window_size + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1274,6 +1394,10 @@ def parity_check_gqa_past( "KV-buffer", "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", + " causal:", + causal, + " local:", + local, " B:", config.batch_size, " S:", @@ -1300,6 +1424,8 @@ def parity_check_gqa_past( def parity_check_gqa_past_no_buff( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1351,6 +1477,15 @@ def parity_check_gqa_past_no_buff( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = k.clone() v_cache_ref = v.clone() @@ -1378,14 +1513,18 @@ def parity_check_gqa_past_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, False) + out, present_k, present_v = gqa_past_func( + q, k, v, config, new_k, new_v, cache_seqlens, past_format, False, window_size=left_window_size + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1401,142 +1540,10 @@ def parity_check_gqa_past_no_buff( # Compare results print( "NO buff", - "past kv format:", - "BSNH" if past_format == Formats.BSNH else "BNSH", - " B:", - config.batch_size, - " S:", - config.sequence_length, - " kv S:", - config.kv_sequence_length, - " N:", - config.num_heads, - " kv N:", - config.kv_num_heads, - " h:", - config.head_size, - " Mean Error:", - numpy.mean(numpy.abs(out - out_ref)), - numpy.allclose( - out, - out_ref, - rtol=rtol, - atol=atol, - equal_nan=True, - ), - ) - - -def parity_check_gqa_past_no_buff_no_mask( - config, - past_format=Formats.BSNH, - rtol=1e-3, - atol=1e-3, -): - q = torch.randn( - config.batch_size, - config.sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - # Pytorch to compare - k_cache_ref = k.clone() - v_cache_ref = v.clone() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - k_cache_ref = torch.cat((k_cache_ref, new_k), 1) - v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - key_padding_mask = None - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, past_format, False) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - # Make sure past-present buffer updating correctly - if past_format == Formats.BSNH: - assert numpy.allclose( - present_k, - k_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - assert numpy.allclose( - present_v, - v_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - else: - assert numpy.allclose( - present_k, - k_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - assert numpy.allclose( - present_v, - v_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - - # Compare results - print( - "Unbuffered", + " causal:", + causal, + " local:", + local, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1663,10 +1670,11 @@ def test_gqa_no_past(self): for sq, skv in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - parity_check_gqa_prompt(config, past_format=past_kv_format) - parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format) + for local in [False, True]: + for past_kv_format in [Formats.BNSH]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt(config, local=local, past_format=past_kv_format) + parity_check_gqa_prompt_no_buff(config, local=local, past_format=past_kv_format) def test_gqa_past(self): if not torch.cuda.is_available(): @@ -1725,24 +1733,25 @@ def test_gqa_past(self): for s, s2 in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for local in [False, True]: + for past_kv_format in [Formats.BNSH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) if __name__ == "__main__": unittest.main() - # test_gqa = TestGQA() - # test_gqa.test_gqa_past() diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py new file mode 100644 index 0000000000000..72ca5d9975c05 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_parity_moe.py @@ -0,0 +1,431 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +import time +import unittest + +import numpy +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from onnx import TensorProto, helper + +import onnxruntime + +torch.manual_seed(42) +numpy.random.seed(42) + + +ORT_DTYPE = TensorProto.FLOAT16 +NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 +THRESHOLD = 3e-2 + + +def value_string_of(numpy_array): + arr = numpy_array.flatten() + lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] + return "{\n " + "f,\n ".join(lines) + "f}" + + +def print_tensor(name, numpy_array): + print(f"const std::vector {name} = {value_string_of(numpy_array)};") + + +def create_moe_onnx_graph( + num_rows, + num_experts, + hidden_size, + inter_size, + fc1_experts_weights, + fc2_experts_weights, + fc1_experts_bias, + fc2_experts_bias, +): + nodes = [ + helper.make_node( + "MoE", + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc2_experts_weights", + "fc1_experts_bias", + "fc2_experts_bias", + ], + ["output"], + "MoE_0", + k=1, + activation_type="gelu", + domain="com.microsoft", + ), + ] + + fc1_shape = [num_experts, hidden_size, inter_size] + fc2_shape = [num_experts, inter_size, hidden_size] + + torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + ORT_DTYPE, + fc1_shape, + fc1_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weights", + ORT_DTYPE, + fc2_shape, + fc2_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + + fc1_bias_shape = [num_experts, inter_size] + fc2_bias_shape = [num_experts, hidden_size] + initializers.extend( + [ + helper.make_tensor( + "fc1_experts_bias", + ORT_DTYPE, + fc1_bias_shape, + fc1_experts_bias.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_bias", + ORT_DTYPE, + fc2_bias_shape, + fc2_experts_bias.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + ORT_DTYPE, + [num_rows, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def get_activation_fn(activation): + if activation == "relu": + return nn.ReLU + elif activation == "gelu": + return nn.GELU + else: + raise NotImplementedError + + +class MoEGate(nn.Module): + def __init__(self, num_experts, in_features): + super().__init__() + self.wg_reduction = torch.nn.Linear(in_features, 16, bias=False) + + wg = torch.empty(num_experts, 16) + torch.nn.init.orthogonal_(wg, gain=0.32) + self.register_parameter("wg", torch.nn.Parameter(wg)) + + def forward(self, input): + input = self.wg_reduction(input) + with torch.no_grad(): + wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True) + self.wg.mul_(1.5 / wg_norm) + logits = self._cosine(input, self.wg) + return logits + + def _cosine(self, mat1, mat2, eps=1e-4): + assert mat1.dim() == 2 + assert mat2.dim() == 2 + + mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps) + return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1) + + +class MoERuntimeExperts(nn.Module): + def __init__( + self, + num_experts, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + bias=True, + chunk_size=-1, + ): + super().__init__() + # assert bias is False, "Current bias is not supported" + assert drop == 0.0, "Current drop is not supported" + assert chunk_size == -1, "Current chunk is not supported" + + self.weight1 = nn.Parameter(torch.rand(num_experts, in_features, hidden_features)) + self.weight2 = nn.Parameter(torch.rand(num_experts, hidden_features, out_features)) + + self.bias1 = nn.Parameter(torch.rand(num_experts, hidden_features)) if bias else None + self.bias2 = nn.Parameter(torch.rand(num_experts, in_features)) if bias else None + + self.act = act_layer() + + def forward(self, x, indices_s): + x = x.unsqueeze(1) + x = self.bmm(x, self.weight1, indices_s) + if self.bias1 is not None: + x = x + self.bias1[indices_s].unsqueeze(1) # S x hidden_features + x = self.act(x) + x = self.bmm(x, self.weight2, indices_s) + if self.bias2 is not None: + x = x + self.bias2[indices_s].unsqueeze(1) # S x 1 x in_features + return x + + def bmm(self, x, weight, indices_s): + x = torch.bmm(x, weight[indices_s]) # S x 1 x hidden_features + return x + + +class MoE(nn.Module): + def __init__( + self, + batch_size, + num_rows, + num_experts, + in_features, + hidden_features=None, + out_features=None, + eval_capacity=-1, + activation="gelu", + ): + super().__init__() + self.num_experts = num_experts + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.eval_capacity = eval_capacity # -1 means we route all tokens + + self.gate = MoEGate(num_experts=num_experts, in_features=in_features) + self.moe_experts = MoERuntimeExperts( + num_experts=num_experts, + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + act_layer=get_activation_fn(activation), + bias=True, + ) + + self.moe_onnx_graph = create_moe_onnx_graph( + batch_size * num_rows, + num_experts, + in_features, + hidden_features, + self.moe_experts.weight1, + self.moe_experts.weight2, + self.moe_experts.bias1, + self.moe_experts.bias2, + ) + + self.ort_sess = self.create_ort_session() + + self.torch_input = torch.randn(batch_size, num_rows, in_features) + + def create_ort_session(self): + from onnxruntime import InferenceSession, SessionOptions + + sess_options = SessionOptions() + + cuda_providers = ["CUDAExecutionProvider"] + if cuda_providers[0] not in onnxruntime.get_available_providers(): + return None + + sess_options.log_severity_level = 2 + ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) + + return ort_session + + def ort_run_with_iobinding(self, ort_inputs, repeat=1000): + iobinding = self.ort_sess.io_binding() + device_id = torch.cuda.current_device() + + iobinding.bind_input( + name="input", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["input"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), + ) + iobinding.bind_input( + name="router_probs", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["router_probs"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( + ort_inputs["router_probs"], "cuda", device_id + ).data_ptr(), + ) + + iobinding.synchronize_inputs() + + iobinding.bind_output( + name="output", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["input"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( + numpy.zeros(ort_inputs["input"].shape), "cuda", device_id + ).data_ptr(), + ) + iobinding.synchronize_outputs() + + s = time.time() + for _ in range(repeat): + self.ort_sess.run_with_iobinding(iobinding) + e = time.time() + print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + + def torch_forward(self): + x = self.torch_input + + b, t, c = x.shape + x = x.reshape(-1, c) + logits = self.gate(x) + gates = torch.nn.functional.softmax(logits, dim=1) + ret = torch.max(gates, dim=1) + indices_s = ret.indices # dim: [bs], the index of the expert with highest softmax value + scores = ret.values.unsqueeze(-1).unsqueeze(-1) # S + x = self.moe_experts(x, indices_s) + + x = x * scores + x = x.reshape(b * t, c) + + return x, torch.sum(x) + + def onnx_forward(self, iobinding=False): + x = self.torch_input + + _, _, c = x.shape + y = x.reshape(-1, c) + logits = self.gate(y) + + ort_inputs = { + "input": numpy.ascontiguousarray(y.detach().numpy().astype(NP_TYPE)), + "router_probs": numpy.ascontiguousarray(logits.detach().numpy().astype(NP_TYPE)), + } + + ort_output = None + if self.ort_sess is not None: + if not iobinding: + ort_output = self.ort_sess.run(None, ort_inputs) + else: + self.ort_run_with_iobinding(ort_inputs) + return None + + # print_tensor("input", ort_inputs["input"]) + # print_tensor("router_probs", ort_inputs["router_probs"]) + # print_tensor("fc1_experts_weights", self.moe_experts.weight1.detach().numpy()) + # print_tensor("fc2_experts_weights", self.moe_experts.weight2.detach().numpy()) + # print_tensor("fc1_experts_bias", self.moe_experts.bias1.detach().numpy()) + # print_tensor("fc2_experts_bias", self.moe_experts.bias2.detach().numpy()) + # print_tensor("output", ort_output[0]) + + return ort_output + + def parity_check(self): + torch_out = self.torch_forward() + ort_out = self.onnx_forward() + if ort_out is not None: + # print("max diff", numpy.max(numpy.abs(torch_out[0].detach().numpy() - ort_out[0]))) + assert numpy.allclose(torch_out[0].detach().numpy(), ort_out[0], rtol=THRESHOLD, atol=THRESHOLD) + + def benchmark(self): + self.onnx_forward(iobinding=True) + + +class TestMoE(unittest.TestCase): + def test_moe_small(self): + rt = MoE( + batch_size=2, + num_rows=8, + num_experts=4, + in_features=16, + hidden_features=32, + out_features=16, + ) + rt.parity_check() + + @pytest.mark.slow + def test_moe_large(self): + for batch_size in [1, 8]: + for num_rows in [16, 64]: + for num_experts in [16, 64]: + for in_features in [256]: + for hidden_features in [512]: + print( + f"batch_size={batch_size}, num_rows={num_rows}, num_experts={num_experts}, in_features={in_features}, hidden_features={hidden_features}" + ) + rt = MoE( + batch_size=batch_size, + num_rows=num_rows, + num_experts=num_experts, + in_features=in_features, + hidden_features=hidden_features, + out_features=in_features, + ) + rt.parity_check() + + @pytest.mark.slow + def test_moe_benchmark(self): + for batch_size in [32, 64]: + for num_rows in [128, 512]: + for num_experts in [64, 128]: + for in_features in [256, 512]: + for hidden_features in [1024, 2048]: + print( + f"batch_size={batch_size}, num_rows={num_rows}, num_experts={num_experts}, in_features={in_features}, hidden_features={hidden_features}" + ) + rt = MoE( + batch_size=batch_size, + num_rows=num_rows, + num_experts=num_experts, + in_features=in_features, + hidden_features=hidden_features, + out_features=in_features, + ) + rt.benchmark() + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py index b17ae5f69aff5..cf8128e0eebcf 100644 --- a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py +++ b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py @@ -261,14 +261,15 @@ def get_eps(self): eps = ["CPUExecutionProvider", "CUDAExecutionProvider"] return list(filter(lambda ep: ep in ort.get_available_providers(), eps)) - def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh): + def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh, transposed=False): eps = self.get_eps() for ep in eps: sess = ort.InferenceSession(onnx_graph, providers=[ep]) output_ort = sess.run(None, inputs_ort)[0] - output_ort = output_ort.reshape( - (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size) - ) + if not transposed: + output_ort = output_ort.reshape( + (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size) + ) # Compare outputs as BxSxNxH self.assertTrue(np.allclose(expected_output_bsnh, output_ort)) @@ -445,6 +446,44 @@ def test_hf_token_rotary_one_pos_id(self): # Compare outputs as BxSxNxH self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + # Bonus test: Prompt step, interleaved = false, pos ids shape = (1), transposed + def test_hf_prompt_rotary_one_pos_id_transposed(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([0]) + onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bnsh.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare outputs as BxNxSxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True) + + # Bonus test: Token generation step, interleaved = false, pos ids shape = (1), transposed + def test_hf_token_rotary_one_pos_id_transposed(self): + x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxSxNxH + + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([2]) + onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bnsh.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Set tranposed=True to compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/transform/gh_issue_18338.onnx b/onnxruntime/test/testdata/transform/gh_issue_18338.onnx new file mode 100644 index 0000000000000..afb499a347ec7 Binary files /dev/null and b/onnxruntime/test/testdata/transform/gh_issue_18338.onnx differ diff --git a/onnxruntime/test/testdata/transform/gh_issue_18338.py b/onnxruntime/test/testdata/transform/gh_issue_18338.py new file mode 100644 index 0000000000000..dc5446ac56c09 --- /dev/null +++ b/onnxruntime/test/testdata/transform/gh_issue_18338.py @@ -0,0 +1,859 @@ +import google.protobuf.text_format +import onnx +from numpy import array, float16 + +import onnxruntime as ort + +# Run n times +N = 1 + +onnx_model_text = """ +ir_version: 8 +producer_name: "pytorch" +producer_version: "2.2.0" +graph { + node { + output: "_val_1" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value_ints" + ints: -1 + type: INTS + } + doc_string: "" + } + node { + input: "input_0" + input: "_val_1" + output: "_val_2" + name: "Reshape_1" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + doc_string: "" + } + node { + input: "_val_2" + output: "_val_3" + name: "_aten_linalg_vector_norm_no_dim_onnx_2" + op_type: "_aten_linalg_vector_norm_no_dim_onnx" + attribute { + name: "keepdim" + i: 0 + type: INT + } + attribute { + name: "ord" + f: 2.0 + type: FLOAT + } + doc_string: "" + domain: "pkg.onnxscript.torch_lib" + } + name: "main_graph" + input { + name: "input_0" + type { + tensor_type { + elem_type: 10 + shape { + } + } + } + } + output { + name: "_val_3" + type { + tensor_type { + elem_type: 10 + shape { + } + } + } + } + value_info { + name: "_val_1" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + value_info { + name: "_val_2" + type { + tensor_type { + elem_type: 10 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + domain: "pkg.onnxscript.torch_lib" + version: 1 +} +opset_import { + domain: "" + version: 18 +} +opset_import { + domain: "pkg.onnxscript.torch_lib.common" + version: 1 +} +functions { + name: "_aten_linalg_vector_norm_no_dim_onnx" + input: "self" + output: "result_29" + attribute: "ord" + attribute: "keepdim" + node { + input: "self" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "self_rank" + name: "n1" + op_type: "Size" + domain: "" + } + node { + output: "int64_0" + name: "n2" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + int64_data: 0 + name: "int64_0" + } + type: TENSOR + } + domain: "" + } + node { + input: "int64_0" + input: "self_rank" + output: "int64_0_cast" + name: "n3" + op_type: "CastLike" + domain: "" + } + node { + input: "self_rank" + input: "int64_0_cast" + output: "cond" + name: "n4" + op_type: "Equal" + domain: "" + } + node { + input: "cond" + output: "self_2" + name: "n5" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + output: "int64_0_1d" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: 0 + name: "int64_0_1d" + } + type: TENSOR + } + domain: "" + } + node { + input: "self" + input: "int64_0_1d" + output: "self_0" + name: "n1" + op_type: "Unsqueeze" + domain: "" + } + name: "thenGraph_4" + output { + name: "self_0" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "self" + output: "self_1" + name: "n0" + op_type: "Identity" + domain: "" + } + name: "elseGraph_4" + output { + name: "self_1" + type { + } + } + } + type: GRAPH + } + domain: "" + } + node { + input: "self_2" + output: "self_3" + name: "n6" + op_type: "Abs" + domain: "" + } + node { + output: "ord" + name: "n7" + op_type: "Constant" + attribute { + name: "value_float" + type: FLOAT + ref_attr_name: "ord" + } + domain: "" + } + node { + input: "ord" + output: "ord_4" + name: "n8" + op_type: "Cast" + attribute { + name: "to" + i: 1 + type: INT + } + domain: "" + } + node { + input: "ord_4" + output: "cond_5" + name: "n9" + op_type: "IsInf" + attribute { + name: "detect_negative" + i: 0 + type: INT + } + attribute { + name: "detect_positive" + i: 1 + type: INT + } + domain: "" + } + node { + input: "cond_5" + output: "result_24" + name: "n10" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result" + name: "n0" + op_type: "ReduceMax" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_9" + output { + name: "result" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "ord_4" + output: "cond_6" + name: "n0" + op_type: "IsInf" + attribute { + name: "detect_negative" + i: 1 + type: INT + } + attribute { + name: "detect_positive" + i: 0 + type: INT + } + domain: "" + } + node { + input: "cond_6" + output: "result_23" + name: "n1" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_7" + name: "n0" + op_type: "ReduceMin" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_11" + output { + name: "result_7" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 0.0 + name: "const" + } + type: TENSOR + } + domain: "" + } + node { + input: "const" + input: "ord_4" + output: "const_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_cast" + output: "cond_8" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_8" + output: "result_22" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "self_bool" + name: "n0" + op_type: "Cast" + attribute { + name: "to" + i: 9 + type: INT + } + domain: "" + } + node { + input: "self_bool" + input: "self_3" + output: "self_0_1" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "self_0_1" + output: "result_9" + name: "n2" + op_type: "ReduceSum" + attribute { + name: "keepdims" + i: 0 + type: INT + } + domain: "" + } + name: "thenGraph_13" + output { + name: "result_9" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const_10" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 1.0 + name: "const_10" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_10" + input: "ord_4" + output: "const_10_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_10_cast" + output: "cond_11" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_11" + output: "result_21" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_12" + name: "n0" + op_type: "ReduceL1" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_18" + output { + name: "result_12" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const_13" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 2.0 + name: "const_13" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_13" + input: "ord_4" + output: "const_13_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_13_cast" + output: "cond_14" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_14" + output: "result_20" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_15" + name: "n0" + op_type: "ReduceL2" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_20" + output { + name: "result_15" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "ord_4" + input: "self_3" + output: "ord_float" + name: "n0" + op_type: "CastLike" + domain: "" + } + node { + input: "self_3" + input: "ord_float" + output: "self_pow" + name: "n1" + op_type: "Pow" + domain: "" + } + node { + input: "self_pow" + output: "tmp_16" + name: "n2" + op_type: "ReduceSum" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + node { + output: "const_17" + name: "n3" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 1.0 + name: "const_17" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_17" + input: "ord_float" + output: "const_17_cast" + name: "n4" + op_type: "CastLike" + domain: "" + } + node { + input: "const_17_cast" + input: "ord_float" + output: "tmp_18" + name: "n5" + op_type: "Div" + domain: "" + } + node { + input: "tmp_16" + input: "tmp_18" + output: "result_19" + name: "n6" + op_type: "Pow" + domain: "" + } + name: "elseGraph_20" + output { + name: "result_19" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_18" + output { + name: "result_20" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_13" + output { + name: "result_21" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_11" + output { + name: "result_22" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_9" + output { + name: "result_23" + type { + } + } + } + type: GRAPH + } + domain: "" + } + node { + output: "int64_0_25" + name: "n11" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + int64_data: 0 + name: "int64_0_25" + } + type: TENSOR + } + domain: "" + } + node { + input: "int64_0_25" + input: "self_rank" + output: "int64_0_25_cast" + name: "n12" + op_type: "CastLike" + domain: "" + } + node { + input: "self_rank" + input: "int64_0_25_cast" + output: "cond_26" + name: "n13" + op_type: "Equal" + domain: "" + } + node { + input: "cond_26" + output: "result_29" + name: "n14" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "result_24" + output: "result_27" + name: "n0" + op_type: "Squeeze" + domain: "" + } + name: "thenGraph_27" + output { + name: "result_27" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "result_24" + output: "result_28" + name: "n0" + op_type: "Identity" + domain: "" + } + name: "elseGraph_27" + output { + name: "result_28" + type { + } + } + } + type: GRAPH + } + domain: "" + } + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib" +} +functions { + name: "Rank" + input: "input" + output: "return_val" + node { + input: "input" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "return_val" + name: "n1" + op_type: "Size" + domain: "" + } + doc_string: "Take the rank of the input tensor." + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib.common" +} +functions { + name: "IsScalar" + input: "input" + output: "return_val" + node { + input: "input" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "tmp_0" + name: "n1" + op_type: "Size" + domain: "" + } + node { + output: "tmp_1" + name: "n2" + op_type: "Constant" + attribute { + name: "value_int" + i: 0 + type: INT + } + domain: "" + } + node { + input: "tmp_0" + input: "tmp_1" + output: "return_val" + name: "n3" + op_type: "Equal" + domain: "" + } + doc_string: "Return whether the input has rank 0, or is a scalar." + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib.common" +} + +""" + +ort_inputs = {"input_0": array(0.8965, dtype=float16)} + +# Set up the inference session +session_options = ort.SessionOptions() +session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL +onnx_model = onnx.ModelProto() +google.protobuf.text_format.Parse(onnx_model_text, onnx_model) + +# Uncomment this line to save the model to a file for examination +# onnx.save_model(onnx_model, "test_output_match_opinfo__linalg_vector_norm_cpu_float16.onnx") + +onnx.checker.check_model(onnx_model) +session = ort.InferenceSession(onnx_model.SerializeToString(), session_options, providers=("CPUExecutionProvider",)) + +# Run the model +for _ in range(N): + ort_outputs = session.run(None, ort_inputs) diff --git a/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.onnx b/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.onnx new file mode 100644 index 0000000000000..afb499a347ec7 Binary files /dev/null and b/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.onnx differ diff --git a/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.py b/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.py new file mode 100644 index 0000000000000..ebda865895d02 --- /dev/null +++ b/onnxruntime/test/testdata/transform/transform_nested_ifs_toplogical_sorted_nodes.py @@ -0,0 +1,859 @@ +import google.protobuf.text_format +import onnx +from numpy import array, float16 + +import onnxruntime as ort + +# Run n times +N = 1 + +onnx_model_text = """ +ir_version: 8 +producer_name: "pytorch" +producer_version: "2.2.0" +graph { + node { + output: "_val_1" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value_ints" + ints: -1 + type: INTS + } + doc_string: "" + } + node { + input: "input_0" + input: "_val_1" + output: "_val_2" + name: "Reshape_1" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + doc_string: "" + } + node { + input: "_val_2" + output: "_val_3" + name: "_aten_linalg_vector_norm_no_dim_onnx_2" + op_type: "_aten_linalg_vector_norm_no_dim_onnx" + attribute { + name: "keepdim" + i: 0 + type: INT + } + attribute { + name: "ord" + f: 2.0 + type: FLOAT + } + doc_string: "" + domain: "pkg.onnxscript.torch_lib" + } + name: "main_graph" + input { + name: "input_0" + type { + tensor_type { + elem_type: 10 + shape { + } + } + } + } + output { + name: "_val_3" + type { + tensor_type { + elem_type: 10 + shape { + } + } + } + } + value_info { + name: "_val_1" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + value_info { + name: "_val_2" + type { + tensor_type { + elem_type: 10 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + domain: "pkg.onnxscript.torch_lib" + version: 1 +} +opset_import { + domain: "" + version: 18 +} +opset_import { + domain: "pkg.onnxscript.torch_lib.common" + version: 1 +} +functions { + name: "_aten_linalg_vector_norm_no_dim_onnx" + input: "self" + output: "result_29" + attribute: "ord" + attribute: "keepdim" + node { + input: "self" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "self_rank" + name: "n1" + op_type: "Size" + domain: "" + } + node { + output: "int64_0" + name: "n2" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + int64_data: 0 + name: "int64_0" + } + type: TENSOR + } + domain: "" + } + node { + input: "int64_0" + input: "self_rank" + output: "int64_0_cast" + name: "n3" + op_type: "CastLike" + domain: "" + } + node { + input: "self_rank" + input: "int64_0_cast" + output: "cond" + name: "n4" + op_type: "Equal" + domain: "" + } + node { + input: "cond" + output: "self_2" + name: "n5" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + output: "int64_0_1d" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: 0 + name: "int64_0_1d" + } + type: TENSOR + } + domain: "" + } + node { + input: "self" + input: "int64_0_1d" + output: "self_0" + name: "n1" + op_type: "Unsqueeze" + domain: "" + } + name: "thenGraph_4" + output { + name: "self_0" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "self" + output: "self_1" + name: "n0" + op_type: "Identity" + domain: "" + } + name: "elseGraph_4" + output { + name: "self_1" + type { + } + } + } + type: GRAPH + } + domain: "" + } + node { + input: "self_2" + output: "self_3" + name: "n6" + op_type: "Abs" + domain: "" + } + node { + output: "ord" + name: "n7" + op_type: "Constant" + attribute { + name: "value_float" + type: FLOAT + ref_attr_name: "ord" + } + domain: "" + } + node { + input: "ord" + output: "ord_4" + name: "n8" + op_type: "Cast" + attribute { + name: "to" + i: 1 + type: INT + } + domain: "" + } + node { + input: "ord_4" + output: "cond_5" + name: "n9" + op_type: "IsInf" + attribute { + name: "detect_negative" + i: 0 + type: INT + } + attribute { + name: "detect_positive" + i: 1 + type: INT + } + domain: "" + } + node { + input: "cond_5" + output: "result_24" + name: "n10" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result" + name: "n0" + op_type: "ReduceMax" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_9" + output { + name: "result" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "ord_4" + output: "cond_6" + name: "n0" + op_type: "IsInf" + attribute { + name: "detect_negative" + i: 1 + type: INT + } + attribute { + name: "detect_positive" + i: 0 + type: INT + } + domain: "" + } + node { + input: "cond_6" + output: "result_23" + name: "n1" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_7" + name: "n0" + op_type: "ReduceMin" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_11" + output { + name: "result_7" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 0.0 + name: "const" + } + type: TENSOR + } + domain: "" + } + node { + input: "const" + input: "ord_4" + output: "const_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_cast" + output: "cond_8" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_8" + output: "result_22" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "self_bool" + name: "n0" + op_type: "Cast" + attribute { + name: "to" + i: 9 + type: INT + } + domain: "" + } + node { + input: "self_bool" + input: "self_3" + output: "self_0_1" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "self_0_1" + output: "result_9" + name: "n2" + op_type: "ReduceSum" + attribute { + name: "keepdims" + i: 0 + type: INT + } + domain: "" + } + name: "thenGraph_13" + output { + name: "result_9" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const_10" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 1.0 + name: "const_10" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_10" + input: "ord_4" + output: "const_10_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_10_cast" + output: "cond_11" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_11" + output: "result_21" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_12" + name: "n0" + op_type: "ReduceL1" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_18" + output { + name: "result_12" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const_13" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 2.0 + name: "const_13" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_13" + input: "ord_4" + output: "const_13_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_13_cast" + output: "cond_14" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_14" + output: "result_20" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_15" + name: "n0" + op_type: "ReduceL2" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_20" + output { + name: "result_15" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "ord_4" + input: "self_3" + output: "ord_float" + name: "n0" + op_type: "CastLike" + domain: "" + } + node { + input: "self_3" + input: "ord_float" + output: "self_pow" + name: "n1" + op_type: "Pow" + domain: "" + } + node { + input: "self_pow" + output: "tmp_16" + name: "n2" + op_type: "ReduceSum" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + node { + output: "const_17" + name: "n3" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 1.0 + name: "const_17" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_17" + input: "ord_float" + output: "const_17_cast" + name: "n4" + op_type: "CastLike" + domain: "" + } + node { + input: "const_17_cast" + input: "ord_float" + output: "tmp_18" + name: "n5" + op_type: "Div" + domain: "" + } + node { + input: "tmp_16" + input: "tmp_18" + output: "result_19" + name: "n6" + op_type: "Pow" + domain: "" + } + name: "elseGraph_20" + output { + name: "result_19" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_18" + output { + name: "result_20" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_13" + output { + name: "result_21" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_11" + output { + name: "result_22" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_9" + output { + name: "result_23" + type { + } + } + } + type: GRAPH + } + domain: "" + } + node { + output: "int64_0_25" + name: "n11" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + int64_data: 0 + name: "int64_0_25" + } + type: TENSOR + } + domain: "" + } + node { + input: "int64_0_25" + input: "self_rank" + output: "int64_0_25_cast" + name: "n12" + op_type: "CastLike" + domain: "" + } + node { + input: "self_rank" + input: "int64_0_25_cast" + output: "cond_26" + name: "n13" + op_type: "Equal" + domain: "" + } + node { + input: "cond_26" + output: "result_29" + name: "n14" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "result_24" + output: "result_27" + name: "n0" + op_type: "Squeeze" + domain: "" + } + name: "thenGraph_27" + output { + name: "result_27" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "result_24" + output: "result_28" + name: "n0" + op_type: "Identity" + domain: "" + } + name: "elseGraph_27" + output { + name: "result_28" + type { + } + } + } + type: GRAPH + } + domain: "" + } + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib" +} +functions { + name: "Rank" + input: "input" + output: "return_val" + node { + input: "input" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "return_val" + name: "n1" + op_type: "Size" + domain: "" + } + doc_string: "Take the rank of the input tensor." + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib.common" +} +functions { + name: "IsScalar" + input: "input" + output: "return_val" + node { + input: "input" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "tmp_0" + name: "n1" + op_type: "Size" + domain: "" + } + node { + output: "tmp_1" + name: "n2" + op_type: "Constant" + attribute { + name: "value_int" + i: 0 + type: INT + } + domain: "" + } + node { + input: "tmp_0" + input: "tmp_1" + output: "return_val" + name: "n3" + op_type: "Equal" + domain: "" + } + doc_string: "Return whether the input has rank 0, or is a scalar." + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib.common" +} + +""" + +ort_inputs = {"input_0": array(0.8965, dtype=float16)} + +# Set up the inference session +session_options = ort.SessionOptions() +session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL +onnx_model = onnx.ModelProto() +google.protobuf.text_format.Parse(onnx_model_text, onnx_model) + +# Uncomment this line to save the model to a file for examination +# onnx.save_model(onnx_model, "transform_nested_ifs_toplogical_sorted_nodes.onnx") + +onnx.checker.check_model(onnx_model) +session = ort.InferenceSession(onnx_model.SerializeToString(), session_options, providers=("CPUExecutionProvider",)) + +# Run the model +for _ in range(N): + ort_outputs = session.run(None, ort_inputs) diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc new file mode 100644 index 0000000000000..0412000e04e1b --- /dev/null +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/optimizer/initializer.h" +#include "orttraining/core/optimizer/conv1d_replacement.h" +#include "core/graph/graph_utils.h" + +/* + In LoRA code, it will use conv1d to do projection for qkv, + while the conv1d calculation is mathematically equivalent to MatMul, and MatMul is much faster than conv1d in GPU. + The graph transformation is doing the following graph substitution: + 1. The input graph is: + conv_input conv_weight + \ / + \ / + conv1d + + 2. The output graph is as follows, + the number of MatMul is equal to attribute "group" of conv1d + conv_input conv1d.group conv_weight conv1d.group + \ / \ / + \ / Squeeze / + \ / \ / + Split Split + / / ... \ / / ... \ + / / ... \ / / ... \ + / / ... \ / / ... \ + input0 input1 ... inputN weight0 weight1 ... weightN + \ \ \ / / / + \ \ \ / / / + \ \ \ / / / + \ \ X / / + \ \ / \ / / + \ \ / X / + \ X / \ / + \ / \ / \ / + MatMul MatMul ... MatMul + \ | ... / + \ | / + \ | / +*/ +namespace onnxruntime { +bool NodeCanBeReplacedByMatmul(const Node& node) { + // If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2, + // then it can be replaced by MatMul + // Kernel_shape is 1 means it is conv1d + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) { + return false; + } + const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); + const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); + const auto* stride = graph_utils::GetNodeAttribute(node, "strides"); + const auto* group = graph_utils::GetNodeAttribute(node, "group"); + if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) { + return false; + } + if ((dilations->ints_size() && dilations->ints(0) != 1) || + (kernel_shape->ints_size() && kernel_shape->ints(0) != 1) || + (stride->ints_size() && stride->ints(0) != 1) || + group->i() >= 3) { + return false; + } + + return true; +} + +void Conv1dToMatmul(Graph& graph, Node& conv) { + // Shape of conv1d input: [batch_size, in_channels, in_length] + // Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1 + // We need to split the input into "group", and squeeze&split the weight, and then do MatMul + const std::string node_description("Conv1dReplacement"); + auto execution_provider_type = conv.GetExecutionProviderType(); + // 1. Split conv input + auto group_attr = graph_utils::GetNodeAttribute(conv, "group"); + int64_t group_num = 1; // default group is 1 from ONNX schema + if (group_attr != nullptr) { + group_num = group_attr->i(); + } + auto conv1d_input = conv.MutableInputDefs()[0]; + std::vector conv1d_input_splitted_outputs; + for (int i = 0; i < group_num; i++) { + conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("input_split_output"), nullptr)); + } + auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input}, + {conv1d_input_splitted_outputs}); + input_split.SetExecutionProviderType(execution_provider_type); + input_split.AddAttribute("axis", int64_t(1)); + auto onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + if (onnx_opset_version >= 18) { + input_split.AddAttribute("num_outputs", group_num); + } + // 2. Squeeze conv weight + auto conv1d_weight = conv.MutableInputDefs()[1]; + auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr); + auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze", + node_description, {conv1d_weight}, {weight_squeeze_output}); + if (onnx_opset_version > 12) { + // After onnx version 12, squeeze node has axes as input instead of attribute + ONNX_NAMESPACE::TensorProto initializer_proto; + initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer")); + initializer_proto.add_dims(static_cast(1)); + initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + InlinedVector initializer_proto_value{2}; + initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); + auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); + // Squeeze node doesn't have opschema here, so we need to set input args count manually + weight_squeeze.MutableInputArgsCount().resize(2); + graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); + } else { + weight_squeeze.AddAttribute("axes", std::vector{2}); + } + weight_squeeze.SetExecutionProviderType(execution_provider_type); + // 3. Split conv weight + std::vector conv1d_weight_splitted_outputs; + for (int i = 0; i < group_num; i++) { + conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("weight_split_output"), nullptr)); + } + auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, + {weight_squeeze_output}, {conv1d_weight_splitted_outputs}); + weight_split.AddAttribute("axis", int64_t(0)); + weight_split.SetExecutionProviderType(execution_provider_type); + if (onnx_opset_version >= 18) { + weight_split.AddAttribute("num_outputs", group_num); + } + // 4. Do MatMul + std::vector matmul_outputs; + for (int i = 0; i < group_num; i++) { + auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr); + matmul_outputs.push_back(matmul_output); + auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description, + {conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]}, + {matmul_output}); + matmul.SetExecutionProviderType(execution_provider_type); + } + // 5. Concat matmul outputs + auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description, + matmul_outputs, {}); + concat_node.SetExecutionProviderType(execution_provider_type); + concat_node.AddAttribute("axis", int64_t(1)); + // 6. Clean up - delted original "conv" node, its output is replaced by concat_node + graph_utils::FinalizeNodeFusion(graph, concat_node, conv); +} + +Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (!node_ptr) + continue; // node was removed + auto& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + if (NodeCanBeReplacedByMatmul(node)) { + LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name(); + Conv1dToMatmul(graph, node); + modified = true; + } + } + return Status::OK(); +} +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.h b/orttraining/orttraining/core/optimizer/conv1d_replacement.h new file mode 100644 index 0000000000000..740f13c76fd6f --- /dev/null +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +class Conv1dReplacement : public GraphTransformer { + public: + Conv1dReplacement(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("Conv1dReplacement", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 57d76577f1ba7..6193a1d10c095 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -72,6 +72,7 @@ #ifdef ENABLE_TRAINING_TORCH_INTEROP #include "orttraining/core/optimizer/pythonop_rewriter.h" #endif +#include "orttraining/core/optimizer/conv1d_replacement.h" namespace onnxruntime { namespace training { @@ -194,6 +195,7 @@ std::vector> GeneratePreTrainingTransformers( // Once we have a CPU kernel for PadAndUnflatten, we can remove the guard. transformers.emplace_back(std::make_unique(compatible_eps, config.sparse_embedding_input_names)); + transformers.emplace_back(std::make_unique(compatible_eps)); #endif } diff --git a/orttraining/orttraining/python/checkpointing_utils.py b/orttraining/orttraining/python/checkpointing_utils.py deleted file mode 100644 index 460b9982297d1..0000000000000 --- a/orttraining/orttraining/python/checkpointing_utils.py +++ /dev/null @@ -1,127 +0,0 @@ -import os - -import torch - - -def list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"): - ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)] - ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)] - ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names] - - assert len(ckpt_file_names) > 0, 'No checkpoint files found with prefix "{}" in directory {}.'.format( - checkpoint_prefix, checkpoint_dir - ) - return ckpt_file_names - - -def get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None): - SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806 - MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806 - - if is_partitioned: - filename = MULTIPLE_CHECKPOINT_FILENAME.format( - prefix=prefix, world_rank=world_rank, world_size=(world_size - 1) - ) - else: - filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix) - - return filename - - -def _split_state_dict(state_dict): - optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] - split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} - for k, v in state_dict.items(): - mode = "fp32_param" - for optim_key in optimizer_keys: - if k.startswith(optim_key): - mode = "optimizer" - break - if k.endswith("_fp16"): - mode = "fp16_param" - split_sd[mode][k] = v - return split_sd - - -class CombineZeroCheckpoint: - def __init__(self, checkpoint_files, clean_state_dict=None): - assert len(checkpoint_files) > 0, "No checkpoint files passed" - self.checkpoint_files = checkpoint_files - self.clean_state_dict = clean_state_dict - self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1 - assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files" - self.weight_shape_map = dict() - self.sharded_params = set() - - def _split_name(self, name: str): - name_split = name.split("_view_") - view_num = None - if len(name_split) > 1: - view_num = int(name_split[1]) - optimizer_key = "" - mp_suffix = "" - if name_split[0].startswith("Moment_1"): - optimizer_key = "Moment_1_" - elif name_split[0].startswith("Moment_2"): - optimizer_key = "Moment_2_" - elif name_split[0].startswith("Update_Count"): - optimizer_key = "Update_Count_" - elif name_split[0].endswith("_fp16"): - mp_suffix = "_fp16" - param_name = name_split[0] - if optimizer_key: - param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split("_fp16")[0] - return param_name, optimizer_key, view_num, mp_suffix - - def _update_weight_statistics(self, name, value): - if name not in self.weight_shape_map: - self.weight_shape_map[name] = value.size() # original shape of tensor - - def _reshape_tensor(self, key): - value = self.aggregate_state_dict[key] - weight_name, _, _, _ = self._split_name(key) - set_size = self.weight_shape_map[weight_name] - self.aggregate_state_dict[key] = value.reshape(set_size) - - def _aggregate(self, param_dict): - for k, v in param_dict.items(): - weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k) - if view_num is not None: - # parameter is sharded - param_name = optimizer_key + weight_name + mp_suffix - - if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]: - self.sharded_params.add(param_name) - # Found a previous shard of the param, concatenate shards ordered by ranks - self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v)) - else: - self.aggregate_state_dict[param_name] = v - else: - if k in self.aggregate_state_dict: - assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value" - else: - self.aggregate_state_dict[k] = v - self._update_weight_statistics(weight_name, v) - - def aggregate_checkpoints(self): - checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0] - self.aggregate_state_dict = dict() - - for i in range(self.world_size): - checkpoint_name = get_checkpoint_name(checkpoint_prefix, True, i, self.world_size) - rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu")) - if "model" in rank_state_dict: - rank_state_dict = rank_state_dict["model"] - - if self.clean_state_dict: - rank_state_dict = self.clean_state_dict(rank_state_dict) - - rank_state_dict = _split_state_dict(rank_state_dict) - self._aggregate(rank_state_dict["fp16_param"]) - self._aggregate(rank_state_dict["fp32_param"]) - self._aggregate(rank_state_dict["optimizer"]) - - for k in self.sharded_params: - self._reshape_tensor(k) - return self.aggregate_state_dict diff --git a/orttraining/orttraining/python/deprecated/__init__.py b/orttraining/orttraining/python/deprecated/__init__.py deleted file mode 100644 index 6e02db707bc47..0000000000000 --- a/orttraining/orttraining/python/deprecated/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -from onnxruntime.capi._pybind_state import TrainingParameters # noqa: F401 -from onnxruntime.capi.training.training_session import TrainingSession # noqa: F401 diff --git a/orttraining/orttraining/python/deprecated/training_session.py b/orttraining/orttraining/python/deprecated/training_session.py deleted file mode 100644 index a6900578e174b..0000000000000 --- a/orttraining/orttraining/python/deprecated/training_session.py +++ /dev/null @@ -1,68 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import os # noqa: F401 -import sys # noqa: F401 - -from onnxruntime.capi import _pybind_state as C -from onnxruntime.capi.onnxruntime_inference_collection import IOBinding # noqa: F401 -from onnxruntime.capi.onnxruntime_inference_collection import ( - InferenceSession, - Session, - check_and_normalize_provider_args, -) - - -class TrainingSession(InferenceSession): - def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None): - Session.__init__(self) - - if sess_options: - self._sess = C.TrainingSession(sess_options) - else: - self._sess = C.TrainingSession() - - # providers needs to be passed explicitly as of ORT 1.10 - # retain the pre-1.10 behavior by setting to the available providers. - if providers is None: - providers = C.get_available_providers() - - providers, provider_options = check_and_normalize_provider_args( - providers, provider_options, C.get_available_providers() - ) - - if isinstance(path_or_bytes, str): - config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options) - elif isinstance(path_or_bytes, bytes): - config_result = self._sess.read_bytes(path_or_bytes, parameters, providers, provider_options) - else: - raise TypeError(f"Unable to load from type '{type(path_or_bytes)}'") - - self.loss_scale_input_name = config_result.loss_scale_input_name - - self._inputs_meta = self._sess.inputs_meta - self._outputs_meta = self._sess.outputs_meta - - def __del__(self): - if self._sess: - self._sess.finalize() - - def get_state(self): - return self._sess.get_state() - - def get_model_state(self, include_mixed_precision_weights=False): - return self._sess.get_model_state(include_mixed_precision_weights) - - def get_optimizer_state(self): - return self._sess.get_optimizer_state() - - def get_partition_info_map(self): - return self._sess.get_partition_info_map() - - def load_state(self, dict, strict=False): - self._sess.load_state(dict, strict) - - def is_output_fp32_node(self, output_name): - return self._sess.is_output_fp32_node(output_name) diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py deleted file mode 100644 index 5286c087cfb64..0000000000000 --- a/orttraining/orttraining/python/ort_trainer.py +++ /dev/null @@ -1,1241 +0,0 @@ -import io -import os -import warnings - -import numpy as np -import onnx -import torch -import torch.nn -import torch.onnx -from onnx import helper, numpy_helper -from packaging.version import Version as LooseVersion - -import onnxruntime as ort -import onnxruntime.capi.pt_patch -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - -from ..training import postprocess -from .checkpointing_utils import CombineZeroCheckpoint, get_checkpoint_name, list_checkpoint_files - -DEFAULT_OPSET_VERSION = 14 - - -class IODescription: - def __init__(self, name, shape, dtype=None, num_classes=None): - self.name_ = name - self.shape_ = shape - self.dtype_ = dtype - self.num_classes_ = num_classes - - -class ModelDescription: - def __init__(self, inputs, outputs): - self.inputs_ = inputs - self.outputs_ = outputs - - -def resolve_symbolic_dimensions(inputs, input_descs, output_descs): - import copy - - output_descs_copy = copy.deepcopy(output_descs) - resolved_dims = {} - for input, input_desc in zip(inputs, input_descs): - for i, axis in enumerate(input_desc.shape_): - if isinstance(axis, str): - resolved_dims[axis] = input.size()[i] - - for output_desc in output_descs_copy: - for i, axis in enumerate(output_desc.shape_): - if isinstance(axis, str): - output_desc.shape_[i] = resolved_dims[axis] - - if any(isinstance(axis, str) for axis in output_desc.shape_ for output_desc in output_descs): - raise RuntimeError("Cannot run model with unknown output dimensions") - - return output_descs_copy - - -def generate_sample(desc, device=None): - # symbolic dimensions are described with strings. set symbolic dimensions to be 1 - size = [s if isinstance(s, (int)) else 1 for s in desc.shape_] - if desc.num_classes_: - return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device) - else: - return torch.randn(size, dtype=desc.dtype_).to(device) - - -def get_device_index(device): - if type(device) == str: # noqa: E721 - # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 - device = torch.device(device) - return 0 if device.index is None else device.index - - -def input_get_device_index(input): - if isinstance(input, (list, tuple)): - device_index = get_device_index(input[0].device) - else: - device_index = get_device_index(input.device) - - return device_index - - -def get_all_gradients_finite_arg_name(session): - all_fp16_or_fp32_gradients_finite_node_args = [x for x in session._outputs_meta if "all_gradients_finite" in x.name] - if len(all_fp16_or_fp32_gradients_finite_node_args) < 1: - raise RuntimeError( - "Failed to find a group NodeArg with name that matches 'all_gradients_finite'\ - from the training session." - ) - - return all_fp16_or_fp32_gradients_finite_node_args[0].name - - -def get_group_accumulated_gradients_output_node_arg_name(session): - # TODO: get the constant string via pybind. - # optimizer_graph_builder BuildGroupNode with fixed string: 'Group_Accumulated_Gradients' - accumulated_gradients_output_node_args = [ - x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name - ] - if len(accumulated_gradients_output_node_args) != 1: - raise RuntimeError( - "Failed to find a group NodeArg with name that matches 'Group_Accumulated_Gradients'\ - from the training session." - ) - - return accumulated_gradients_output_node_args[0].name - - -def ort_training_session_run_helper(session, iobinding, inputs, input_descs, output_descs, device, run_options=None): - for input, input_desc in zip(inputs, input_descs): - device_index = input_get_device_index(input) - iobinding.bind_input( - input_desc.name_, - input.device.type, - device_index, - dtype_torch_to_numpy(input.dtype), - list(input.size()), - input.data_ptr(), - ) - - output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs) - torch_outputs = {} - for output_desc in output_descs_resolved: - torch_tensor = torch.zeros( - output_desc.shape_, - device=device, - dtype=output_desc.eval_dtype_ if hasattr(output_desc, "eval_dtype_") else output_desc.dtype_, - ) - iobinding.bind_output( - output_desc.name_, - torch_tensor.device.type, - get_device_index(device), - dtype_torch_to_numpy(torch_tensor.dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - torch_outputs[output_desc.name_] = torch_tensor - - session.run_with_iobinding(iobinding, run_options) - return torch_outputs - - -def FuseSofmaxNLLToSoftmaxCE(onnx_model): # noqa: N802 - nll_count = 0 - while True: - nll_count = nll_count + 1 - nll_loss_node = None - nll_loss_node_index = 0 - for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss": - nll_loss_node = node - break - - if nll_loss_node is None: - break - - softmax_node = None - softmax_node_index = 0 - label_input_name = None - weight_input_name = None - for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "LogSoftmax": - # has to be connected to nll_loss - if len(nll_loss_node.input) > 2: - weight_input_name = nll_loss_node.input[2] - if node.output[0] == nll_loss_node.input[0]: - softmax_node = node - label_input_name = nll_loss_node.input[1] - break - elif node.output[0] == nll_loss_node.input[1]: - softmax_node = node - label_input_name = nll_loss_node.input[0] - break - else: - if softmax_node is not None: - break - - if softmax_node is None: - break - - # delete nll_loss and LogSoftmax nodes in order - if nll_loss_node_index < softmax_node_index: - del onnx_model.graph.node[softmax_node_index] - del onnx_model.graph.node[nll_loss_node_index] - else: - del onnx_model.graph.node[nll_loss_node_index] - del onnx_model.graph.node[softmax_node_index] - - probability_output_name = softmax_node.output[0] - node = onnx_model.graph.node.add() - inputs = ( - [softmax_node.input[0], label_input_name, weight_input_name] - if weight_input_name - else [softmax_node.input[0], label_input_name] - ) - node.CopyFrom( - onnx.helper.make_node( - "SparseSoftmaxCrossEntropy", - inputs, - [nll_loss_node.output[0], probability_output_name], - "nll_loss_node_" + str(nll_count), - ) - ) - - return onnx_model - - -def delete_input_with_name(input, name): - index = 0 - for i in input: - if i.name == name: - del input[index] - break - index = index + 1 - - -# reference: -# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html -# https://pytorch.org/docs/stable/tensors.html -# also must map to types accepted by: -# MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type) -def dtype_torch_to_numpy(torch_dtype): - if torch_dtype == torch.float64 or torch_dtype == torch.double: - return np.float64 - elif torch_dtype == torch.float32 or torch_dtype == torch.float: - return np.float32 - elif torch_dtype == torch.float16 or torch_dtype == torch.half: - return np.float16 - elif torch_dtype == torch.int64 or torch_dtype == torch.long: - return np.longlong - elif torch_dtype == torch.int32 or torch_dtype == torch.int: - return np.int32 - elif torch_dtype == torch.int16 or torch_dtype == torch.short: - return np.int16 - elif torch_dtype == torch.bool: - return bool - else: - raise Exception("Torch type to numpy type mapping unavailable for: " + str(torch_dtype)) - - -class model_loss_cls(torch.nn.Module): # noqa: N801 - def __init__(self, model, loss_fn): - super().__init__() - self.model_ = model - self.loss_fn_ = loss_fn - - def forward(self, *inputs): - # here we assume input can be unpacked into input and label - input, label = inputs[:-1], inputs[-1] - preds = self.model_(*input) - return self.loss_fn_(preds, label), preds - - -class WrapModel(torch.nn.Module): - def __init__(self, model, loss_fn, input_names): - super().__init__() - self.model_ = model - self.loss_fn_ = loss_fn - self.input_names_ = input_names - - def forward(self, *inputs): - import inspect - - # *inputs is given by torch trace. It is in the order of input_names. - # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names. - sig = inspect.signature(self.model_.forward) - list(sig.parameters.keys()) - - input_dict = {} - for key in sig.parameters: - if key in self.input_names_: - input_dict[key] = inputs[self.input_names_.index(key)] - - model_out = self.model_(**input_dict) - if self.loss_fn_ is None: - return model_out - - label = inputs[-1] - preds = model_out - return self.loss_fn_(preds, label), preds - - -def wrap_for_input_match(model, loss_fn, input_names): - import inspect - - sig = inspect.signature(model.forward) - ordered_list_keys = list(sig.parameters.keys()) - if loss_fn: - sig_loss = inspect.signature(loss_fn) - if len(sig_loss.parameters) != 2: - raise RuntimeError("loss function should take two arguments - predict and label.") - - # label shall be the second input to loss_fn. - ordered_list_keys = [*ordered_list_keys, list(sig_loss.parameters.keys())[1]] - - # name match is needed only when input_names are a subset - # of expected inputs (inputs to model and loss_fn combined). - if len(input_names) > len(ordered_list_keys): - # this is likely the case where input arguments are packed. - # TODO: to unpack the input argument. - return model_loss_cls(model, loss_fn) if loss_fn else model - elif len(input_names) == len(ordered_list_keys): - # in this case, we do not require name match. - return model_loss_cls(model, loss_fn) if loss_fn else model - - if not all(x in ordered_list_keys for x in input_names): - # model desc has name(s) not matching the model signature. We cannot do anything in this case. - # better to warning the user. - return model_loss_cls(model, loss_fn) if loss_fn else model - - # if input_names match ordered_list_keys, there is not need for wrapping - match = True - for i, input_name in enumerate(input_names): - if input_name != ordered_list_keys[i]: - match = False - break - - if match: - return model_loss_cls(model, loss_fn) if loss_fn else model - - model = WrapModel(model, loss_fn, input_names) - - return model - - -def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, opset_version=DEFAULT_OPSET_VERSION): - # example: {input0:{0:'batch'}, input1:{0:'batch'}} - dynamic_axes = {} - for input in model_desc.inputs_: - symbolic_axis = {} - for i, axis in enumerate(input.shape_): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[input.name_] = symbolic_axis - - for output in model_desc.outputs_: - symbolic_axis = {} - for i, axis in enumerate(output.shape_): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[output.name_] = symbolic_axis - - input_names = [input.name_ for input in model_desc.inputs_] - output_names = [output.name_ for output in model_desc.outputs_] - - if isinstance(inputs, torch.Tensor): - inputs = [inputs] - if isinstance(inputs, dict): - sample_inputs = [inputs[k.name_].to(device=device) for k in model_desc.inputs_] - elif isinstance(inputs, (list, tuple)): - sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(model_desc.inputs_)] - else: - raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.") - - # pytorch onnx exporter/trace does not try to match argument names. - # e.g. for models with optional inputs, it requires all inputs be present. - # this is a problem because the model graph depends on inputs provided. - model = wrap_for_input_match(model, loss_fn, input_names) - - model.eval() - with torch.no_grad(): - import copy - - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - try: - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(model) - except Exception: - model_copy = model - warnings.warn( - "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." - " Compute will continue, but unexpected results may occur!" - ) - - sample_outputs = model_copy(*sample_inputs_copy) - if isinstance(sample_outputs, torch.Tensor): - sample_outputs = [sample_outputs] - for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_): - output_desc.dtype_ = sample_output.dtype - model.train() - - f = io.BytesIO() - - # Other export options to use(this is for backward compatibility). - other_export_options = {} - other_export_options["training"] = True - - # This option was added after 1.4 release. - if LooseVersion(torch.__version__) > LooseVersion("1.4.0") and LooseVersion(torch.__version__) < LooseVersion( - "1.10.0" - ): - other_export_options["enable_onnx_checker"] = False - # This option was added after 1.6 release. - if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - other_export_options["training"] = torch.onnx.TrainingMode.TRAINING - - # Deepcopy inputs, since input values may change after model run. - import copy - - sample_inputs_copy = copy.deepcopy(sample_inputs) - - # Enable contrib ops export from PyTorch - from onnxruntime.tools import pytorch_export_contrib_ops - - pytorch_export_contrib_ops.register() - - torch.onnx._export( - model, - tuple(sample_inputs_copy), - f, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False, - **other_export_options, - ) - - onnx_model = onnx.load_model_from_string(f.getvalue()) - - # Remove 'model_.' prefix introduced by model wrapper for initializers. - if isinstance(model, (WrapModel, model_loss_cls)): - replace_name_dict = {} - for n in onnx_model.graph.initializer: - if n.name.startswith("model_."): - replace_name_dict[n.name] = n.name[len("model_.") :] - n.name = replace_name_dict[n.name] - for n in onnx_model.graph.node: - for i, name in enumerate(n.input): - if name in replace_name_dict: - n.input[i] = replace_name_dict[name] - - return onnx_model - - -def create_ort_training_session_with_optimizer( - model, - device, - training_optimizer_name, - lr_params_feed_name, - map_optimizer_attributes, - world_rank=-1, - world_size=1, - gradient_accumulation_steps=1, - bind_parameters=False, - use_mixed_precision=False, - allreduce_post_accumulation=False, - deepspeed_zero_stage=0, - enable_grad_norm_clip=True, - frozen_weights=[], # noqa: B006 - opset_version=DEFAULT_OPSET_VERSION, - use_deterministic_compute=False, - use_memory_efficient_gradient=False, - enable_adasum=False, - optimized_model_filepath="", -): - output_name = model.graph.output[0].name - ort_parameters = ort.TrainingParameters() - ort_parameters.loss_output_name = output_name - ort_parameters.use_mixed_precision = use_mixed_precision - ort_parameters.world_rank = world_rank - ort_parameters.world_size = world_size - ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps - ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation - ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage - ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip - ort_parameters.set_gradients_as_graph_outputs = False - ort_parameters.use_memory_efficient_gradient = use_memory_efficient_gradient - ort_parameters.enable_adasum = enable_adasum - output_types = {} - for output in model.graph.output: - output_types[output.name] = output.type.tensor_type - - # pybind does not allow to add directly to ort_parameters.weights_to_train. - # Have to work around by using a temporary weights_to_train. - torch_params = {} - optimizer_attributes_map = {} - optimizer_int_attributes_map = {} - - unused_frozen_weights = [n for n in frozen_weights if n not in [i.name for i in model.graph.initializer]] - if unused_frozen_weights: - raise RuntimeError(f"{unused_frozen_weights} in frozen_weights not found in model weights.") - - weights_to_train = set() - for initializer in model.graph.initializer: - if initializer.name in frozen_weights: - continue - weights_to_train.add(initializer.name) - if map_optimizer_attributes is not None: - attributes = map_optimizer_attributes(initializer.name) - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - for k, v in attributes.items(): - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - else: - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - - if bind_parameters: - for initializer in model.graph.initializer: - torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device)) - delete_input_with_name(model.graph.input, initializer.name) - model.graph.input.extend( - [helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)] - ) - torch_params[initializer.name] = torch_tensor - - del model.graph.initializer[:] - - ort_parameters.weights_to_train = weights_to_train - ort_parameters.training_optimizer_name = training_optimizer_name - ort_parameters.lr_params_feed_name = lr_params_feed_name - ort_parameters.optimizer_attributes_map = optimizer_attributes_map - ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map - - sessionOptions = ort.SessionOptions() # noqa: N806 - sessionOptions.use_deterministic_compute = use_deterministic_compute - if len(optimized_model_filepath) > 0: - sessionOptions.optimized_model_filepath = optimized_model_filepath - session = ort.TrainingSession(model.SerializeToString(), ort_parameters, sessionOptions) - train_io_binding = session.io_binding() - eval_io_binding = session.io_binding() - - if bind_parameters: - for param in torch_params: - torch_tensor = torch_params[param] - - train_io_binding.bind_input( - param, - torch_tensor.device.type, - get_device_index(torch_tensor.device), - dtype_torch_to_numpy(torch_params[param].dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - eval_io_binding.bind_input( - param, - torch_tensor.device.type, - get_device_index(torch_tensor.device), - dtype_torch_to_numpy(torch_params[param].dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - - return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types - - -def save_checkpoint( - model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", checkpoint_state_dict=None, include_optimizer_state=True -): - if checkpoint_state_dict is None: - checkpoint_state_dict = {"model": model.state_dict(include_optimizer_state)} - else: - checkpoint_state_dict.update({"model": model.state_dict(include_optimizer_state)}) - - assert os.path.exists(checkpoint_dir), f"ERROR: Checkpoint directory doesn't exist: {checkpoint_dir}" - - checkpoint_name = get_checkpoint_name( - checkpoint_prefix, model.deepspeed_zero_stage_, model.world_rank, model.world_size - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if os.path.exists(checkpoint_file): - warnings.warn(f"{checkpoint_file} already exists, overwriting.") - - torch.save(checkpoint_state_dict, checkpoint_file) - - -def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict): - checkpoint_name = get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if is_partitioned: - assert_msg = ( - f"Couldn't find checkpoint file {checkpoint_file}." - "Optimizer partitioning is enabled using ZeRO. Please make sure that the " - f"checkpoint file exists for rank {model.world_rank} of {model.world_size}." - ) - else: - assert_msg = f"Couldn't find checkpoint file {checkpoint_file}." - - assert os.path.exists(checkpoint_file), assert_msg - - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - - model.load_state_dict(checkpoint_state["model"], strict=strict) - del checkpoint_state["model"] - return checkpoint_state - - -def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict): - checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - - ckpt_agg = CombineZeroCheckpoint(checkpoint_files) - aggregate_state_dict = ckpt_agg.aggregate_checkpoints() - - model.load_state_dict(aggregate_state_dict, strict=strict) - - # aggregate other keys in the state_dict. - # Values will be overwritten for matching keys among workers - all_checkpoint_states = {} - for checkpoint_file in checkpoint_files: - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - del checkpoint_state["model"] - all_checkpoint_states.update(checkpoint_state) - return all_checkpoint_states - - -def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): - checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - is_partitioned = False - if len(checkpoint_files) > 1: - warnings.warn( - f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." - "Attempting to load ZeRO checkpoint." - ) - is_partitioned = True - if (not model.deepspeed_zero_stage_) and is_partitioned: - return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict) - else: - return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) - - -class ORTTrainer: - def __init__( - self, - model, - loss_fn, - model_desc, - training_optimizer_name, - map_optimizer_attributes, - learning_rate_description, - device, - gradient_accumulation_steps=1, - world_rank=0, - world_size=1, - use_mixed_precision=False, - allreduce_post_accumulation=False, - global_step=0, - get_lr_this_step=None, - loss_scaler=None, - deepspeed_zero_stage=0, - enable_grad_norm_clip=True, - frozen_weights=[], # noqa: B006 - _opset_version=DEFAULT_OPSET_VERSION, - _enable_internal_postprocess=True, - _extra_postprocess=None, - _use_deterministic_compute=False, - use_memory_efficient_gradient=False, - run_symbolic_shape_infer=False, - enable_adasum=False, - optimized_model_filepath="", - ): - super().__init__() - """ - Initialize ORTTrainer. - - Args: - - model: one of - - a PyTorch model (class that inherits from torch.nn.Module) - - a combined PyTorch model and loss function. - Inputs to this combined PyTorch model are a concatenation of the - model's input and the loss function's label input. - Outputs are a concatenation of the loss function's output and the - model's output. - - a combined ONNX model and loss function. - loss_fn: one of - - a PyTorch loss function if 'model' is a PyTorch model. A loss - function takes two inputs (prediction, label) and outputs a loss - tensor. - - None if model is already combined with a loss function. - model_desc: Specify input/output shapes, types, and names. - Must be consistent with the training model. - training_optimizer_name: one of - - 'SGDOptimizer' - - 'AdamOptimizer' - - 'LambOptimizer' - map_optimizer_attributes: for optimizers with weight-dependent - parameters. A callable that maps weight name to a set of optimization - parameters. - Defaults to None. - learning_rate_description: the name, shape and type of the learning - rate in form of IODescription(Learning_Rate_Name, [1,], torch.float32). - Because learning_rate is an input to the training model, - Learning_Rate_Name must be specified so that there is no name conflict - within the model. - device: device to store tensors (e.g. 'cpu', 'cuda', 'cuda:'). - gradient_accumulation_steps: number of training steps to accumulate - gradients before averaging and applying them. - Defaults to 1. - world_rank: rank id used for distributed training. - Defaults to 0. - world_size: number of ranks participating in distributed training. - Defaults to 1. - use_mixed_precision: flag to enable mixed precision (aka fp16). - Defaults to False. - allreduce_post_accumulation: controls whether overlaping gradient - computation is applied with allreduce. - Defaults to False. - global_step: training step that is used as input to 'get_lr_this_step'. - Defaults to 0. - get_lr_this_step: functor used as learning rate scheduler. - It uses 'global_step' as input. - Defaults to None. - loss_scaler: updates loss scale automatically when 'use_mixed_precision' - is specified. - Defaults to None. - deepspeed_zero_stage: controls whether to partition state using the DeepSpeed ZeRO technique. Stages 0 and 1 are supported. - Defaults to 0 (disabled). - enable_grad_norm_clip: enables gradient norm clipping. - Defaults to True. - frozen_weights: list of model parameters to be frozen (not trained). - Defaults to []. - _enable_internal_postprocess: whether to run or not the internal postprocesses. - Defaults to True - _extra_postprocess: a callable to postprocess the ONNX model that is converted from PyTorch. - Defaults to None - use_memory_efficient_gradient: use memory aware gradient builder. - Defaults to False - run_symbolic_shape_infer: run symbolic shape inference - Defaults to False - optimized_model_filepath: path to output the optimized training graph. - Defaults to "" (no output). - """ - warnings.warn( - "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", - FutureWarning, - ) - warnings.warn( - "DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it" - ) - self.is_train = True - - self.torch_model_ = None - self.onnx_model_ = None - self._enable_internal_postprocess = _enable_internal_postprocess - self._extra_postprocess = _extra_postprocess - - if isinstance(model, torch.nn.Module): - self.torch_model_ = model - self.loss_fn_ = loss_fn - self._torch_state_dict_keys = list(model.state_dict().keys()) - else: - self._torch_state_dict_keys = [] - self.onnx_model_ = model - if loss_fn is not None: - warnings.warn("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.") - # TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn - self.loss_fn_ = None - - if self._enable_internal_postprocess: - postprocess.run_postprocess(self.onnx_model_) - - if self._extra_postprocess: - self._extra_postprocess(self.onnx_model_) - - self.model_desc_ = model_desc - self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description] - - self.world_rank = world_rank - self.world_size = world_size - self.use_mixed_precision = use_mixed_precision - - self.session = None - self.device_ = device - self.gradient_accumulation_steps = gradient_accumulation_steps - # we use self.current_step to count calls to train_step. It is used for gradient accumulation. - # gradients are being accumulated when self.current_step is not divisible by gradient_accumulation_steps. - # gradients are updated when self.current_step is divisible by gradient_accumulation_steps. - self.current_step = 0 - - # we use self.global_step_ to count optimizations being performed. - # it is used to calculate learning rate if self.get_lr_this_step_ is provided. - self.global_step_ = global_step - self.get_lr_this_step_ = get_lr_this_step - self.loss_scaler_ = loss_scaler - - if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None: - warnings.warn("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.") - self.training_optimizer_name_ = training_optimizer_name - self.learning_rate_description_ = learning_rate_description - self.map_optimizer_attributes_ = map_optimizer_attributes - self.allreduce_post_accumulation_ = allreduce_post_accumulation - self.deepspeed_zero_stage_ = deepspeed_zero_stage - self.enable_grad_norm_clip_ = enable_grad_norm_clip - self.frozen_weights_ = frozen_weights - self.opset_version_ = _opset_version - self.state_dict_ = None - self._use_deterministic_compute = _use_deterministic_compute - self.use_memory_efficient_gradient = use_memory_efficient_gradient - self.run_symbolic_shape_infer = run_symbolic_shape_infer - self.enable_adasum = enable_adasum - self.optimized_model_filepath = optimized_model_filepath - - # use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs. - # see prepare_input_and_fetches for more details. - self.loss_scale_input_name = "default_loss_scale_input_name" - - self._init_session() - - def _init_session(self): - if self.onnx_model_ is None: - return - - self._verify_fully_optimized_model(self.onnx_model_) - - if self.run_symbolic_shape_infer: - self.onnx_model_ = SymbolicShapeInference.infer_shapes( - self.onnx_model_, auto_merge=True, guess_output_rank=True - ) - - # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. - # for example, load_state_dict will be called before returing the function, and it calls _init_session again - del self.session - ( - self.session, - self.train_io_binding, - self.eval_io_binding, - self.output_name, - _, - self.output_types, - ) = create_ort_training_session_with_optimizer( - self.onnx_model_, - self.device_, - self.training_optimizer_name_, - self.learning_rate_description_.name_, - self.map_optimizer_attributes_, - self.world_rank, - self.world_size, - self.gradient_accumulation_steps, - bind_parameters=False, - use_mixed_precision=self.use_mixed_precision, - allreduce_post_accumulation=self.allreduce_post_accumulation_, - deepspeed_zero_stage=self.deepspeed_zero_stage_, - enable_grad_norm_clip=self.enable_grad_norm_clip_, - frozen_weights=self.frozen_weights_, - opset_version=self.opset_version_, - use_deterministic_compute=self._use_deterministic_compute, - use_memory_efficient_gradient=self.use_memory_efficient_gradient, - enable_adasum=self.enable_adasum, - optimized_model_filepath=self.optimized_model_filepath, - ) - - self.loss_scale_input_name = self.session.loss_scale_input_name - - if self.use_mixed_precision: - self.input_desc_with_lr_and_loss_scale = [ - *self.input_desc_with_lr, - IODescription(self.loss_scale_input_name, [], torch.float32), - ] - - # ORT backend has modified model output dtype from float32 to float16. - for o_desc in self.model_desc_.outputs_: - if ( - self.use_mixed_precision - and o_desc.dtype_ == torch.float32 - and not self.session.is_output_fp32_node(o_desc.name_) - ): - o_desc.eval_dtype_ = torch.float16 - else: - o_desc.eval_dtype_ = o_desc.dtype_ - - # gradient accumulation buffers are connected to a single node with a boolean, dimension 1 tensor output. - # add a matching output to drive gradient accumulation. - if self.gradient_accumulation_steps > 1: - self.output_desc_with_group_accumulated_gradients = [ - *self.model_desc_.outputs_, - IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool), - ] - - if self.use_mixed_precision: - # when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine - # if the gradient is usable. - self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [ - *self.model_desc_.outputs_, - IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool), - ] - - if self.state_dict_: - self.load_state_dict(self.state_dict_, self.strict_) - self.state_dict_ = None - - def _init_onnx_model(self, inputs): - if self.onnx_model_ is not None: - return - - if self.torch_model_ is not None: - # NOTE: pt model is moved to cpu to conserve gpu memory. - self.torch_model_.cpu() - # torch buffers created using 'register_buffer' are not meant to be trainable. - torch_buffers = list(dict(self.torch_model_.named_buffers()).keys()) - self.frozen_weights_ = self.frozen_weights_ + torch_buffers - self.onnx_model_ = convert_model_loss_fn_to_onnx( - self.torch_model_, - self.loss_fn_, - self.model_desc_, - torch.device("cpu"), - inputs, - opset_version=self.opset_version_, - ) - - if self._enable_internal_postprocess: - postprocess.run_postprocess(self.onnx_model_) - - if self._extra_postprocess: - self._extra_postprocess(self.onnx_model_) - - self._init_session() - - def train(self): - self.is_train = True - - def eval(self): - self.is_train = False - - def _update_onnx_model_initializers(self, state_tensors): - # replace the initializers with new value - new_weights = [] - replace_indices = [] - for i, w in enumerate(self.onnx_model_.graph.initializer): - if w.name in state_tensors: - new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name)) - replace_indices.append(i) - replace_indices.sort(reverse=True) - for w_i in replace_indices: - del self.onnx_model_.graph.initializer[w_i] - self.onnx_model_.graph.initializer.extend(new_weights) - - def state_dict(self, include_optimizer_state=True): - if not self.session: - warnings.warn( - "ONNXRuntime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling state_dict()." - ) - return {} - - # extract trained weights - session_state = self.session.get_state() - torch_state = {} - for name in session_state: - torch_state[name] = torch.from_numpy(session_state[name]) - - # extract untrained weights and buffer - for n in self.onnx_model_.graph.initializer: - if n.name not in torch_state: - torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n)) - - # Need to remove redundant initializers and name suffices to map back to original torch state names - if not include_optimizer_state and self._torch_state_dict_keys: - return {key: torch_state[key] for key in self._torch_state_dict_keys if key in torch_state} - return torch_state - - def load_state_dict(self, state_dict, strict=False): - # Note: It may happen ONNX model has not yet been initialized - # In this case we cache a reference to desired state and delay the restore until after initialization - # Unexpected behavior will result if the user changes the reference before initialization - if not self.session: - self.state_dict_ = state_dict - self.strict_ = strict - return - - # update onnx model from loaded state dict - cur_initializers_names = [n.name for n in self.onnx_model_.graph.initializer] - new_initializers = {} - - for name in state_dict: - if name in cur_initializers_names: - new_initializers[name] = state_dict[name].numpy() - elif strict: - raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.") - - self._update_onnx_model_initializers(new_initializers) - - # create new session based on updated onnx model - self.state_dict_ = None - self._init_session() - - # load training state - session_state = {name: state_dict[name].numpy() for name in state_dict} - self.session.load_state(session_state, strict) - - def save_as_onnx(self, path): - if not self.session: - warnings.warn( - "ONNXRuntime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling save_as_onnx()." - ) - return - state_tensors = self.session.get_state() - self._update_onnx_model_initializers(state_tensors) - - with open(path, "wb") as f: - f.write(self.onnx_model_.SerializeToString()) - - def _prepare_input_and_fetches( - self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs - ): - fetches = None - if type(args) == tuple and len(args) == 1 and type(args[0]) == list: # noqa: E721 - input = tuple(args[0]) - else: - input = args - - for input_desc in input_desc_with_: - if input_desc.name_ in kwargs: - input = (*input, kwargs[input_desc.name_]) - if internal_learning_rate is not None: - input = (*input, internal_learning_rate) - if internal_loss_scale is not None: - input = (*input, internal_loss_scale) - elif self.use_mixed_precision: - # loss_scale input name is needed to call train_step, for example: - # kwargs[model.loss_scale_input_name] = loss_scale - # outputs = model.train_step(*args, **kwargs) - # However, when first time train_step is called model.loss_scale_input_name is not set. - # To workaround this problem, we use the special name 'default_loss_scale_input_name' to indicate - # the loss_scale. - if "default_loss_scale_input_name" in kwargs: - input = (*input, kwargs["default_loss_scale_input_name"]) - - fetches = None - if "fetches" in kwargs: - fetches = kwargs["fetches"] - - return input, fetches - - def train_step(self, *args, **kwargs): - """ - inputs: model inputs, labels, learning rate, and, if in mixed_precision mode, loss_scale. - outputs: if fetches is not provided, outputs are loss and - (if in mixed mode and is finishing gradient accumulation) all_finite. - if fetches is provided, outputs contains these requested with fetches. - fetches: names of requested outputs - """ - - # inputs to the ONNX model includes inputs to the original PyTorch model - # plus learning rate and loss_scale if self.use_mixed_precision is True. - # 1. when there are internal learning_rate and loss_scale (in fp16 cases) generators, - # *args and **kwargs together contain ONLY and COMPLETE inputs to the PyTorch model. - # In this case, changes to the training script is minimized. - # 2. without internal learning rate and loss scale (in fp16 cases) generators, - # *args and **kwargs passed in from the training script shall contains - # inputs to the PyTorch model plus learning_rate and loss_scale. - # it optionally contains the fetches. - # localized arguments (*args) contains inputs to the ONNX model. - # named arguments can contain both inputs, learning_rate and loss_scale, and the fetches - - learning_rate, loss_scale = None, None - if self.get_lr_this_step_ is not None: - # $args, **kwargs contains inputs to the pytorch model - lr_this_step = self.get_lr_this_step_(self.global_step_) - learning_rate = torch.tensor([lr_this_step]) - if self.loss_scaler_ is not None and self.use_mixed_precision: - loss_scale = torch.tensor([self.loss_scaler_.loss_scale_]) - - if self.onnx_model_ is None: - sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) - self._init_onnx_model(sample_input) - - if self.use_mixed_precision: - input, fetches = self._prepare_input_and_fetches( - self.input_desc_with_lr_and_loss_scale, learning_rate, loss_scale, *args, **kwargs - ) - assert len(self.input_desc_with_lr_and_loss_scale) == len(input) - input_descs = self.input_desc_with_lr_and_loss_scale - else: - input, fetches = self._prepare_input_and_fetches( - self.input_desc_with_lr, learning_rate, loss_scale, *args, **kwargs - ) - assert len(self.input_desc_with_lr) == len(input) - input_descs = self.input_desc_with_lr - - self.current_step += 1 - - # handle gradient accumulation in fully optimized mode - run_options = None - has_if_all_finite = False - if fetches: - output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch] - elif self.current_step % self.gradient_accumulation_steps != 0: - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - output_desc = self.output_desc_with_group_accumulated_gradients - elif self.use_mixed_precision: - has_if_all_finite = True - output_desc = self.output_desc_with_all_fp_16_or_fp32_gradients_finite - else: - output_desc = self.model_desc_.outputs_ - - if not isinstance(input, (list, tuple)): - input = (input,) - - session_run_results = ort_training_session_run_helper( - self.session, self.train_io_binding, input, input_descs, output_desc, self.device_, run_options - ) - - if has_if_all_finite: - # After session run with all_fp32_gradients_finite, we need to clear the iobinding's output state. - # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce - # because all_fp32_gradients_finite is still in the feed. - self.train_io_binding.clear_binding_outputs() - all_finite = session_run_results[self.output_desc_with_all_fp_16_or_fp32_gradients_finite[-1].name_] - if self.loss_scaler_ is not None: - self.loss_scaler_.update_loss_scale(all_finite) - if all_finite: - # optimization has done, increase self.global_step_ - self.global_step_ = self.global_step_ + 1 - elif self.current_step % self.gradient_accumulation_steps == 0: - # optimization has done, increase self.global_step_ - self.global_step_ = self.global_step_ + 1 - - if fetches is not None: - results = [session_run_results[fetch] for fetch in fetches] - elif has_if_all_finite and self.loss_scaler_ is None: - # return descripted outputs plus the all_finite flag so that the training script can handle loss scaling. - results = [ - session_run_results[output_desc.name_] - for output_desc in self.output_desc_with_all_fp_16_or_fp32_gradients_finite - ] - else: - results = [session_run_results[output_desc.name_] for output_desc in self.model_desc_.outputs_] - return results[0] if len(results) == 1 else results - - def __call__(self, *args, **kwargs): - if self.is_train: - return self.train_step(*args, **kwargs) - else: - return self.eval_step(*args, **kwargs) - - def eval_step(self, *args, **kwargs): - """ - inputs: model inputs and/or labels. - outputs: if 'fetches' is not provided, outputs are loss and - (if in mixed mode and is finishing gradient accumulation) all_finite. - if fetches is provided, outputs contains these requested with fetches. - fetches: names of requested outputs - """ - - # with model_loss_cls, the last input is label, first output is loss - input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) - - if self.onnx_model_ is None: - if self.torch_model_ is not None: - self._init_onnx_model(input) - else: - raise RuntimeError( - "Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer." - ) - - input_desc = self.model_desc_.inputs_[0 : len(input)] - if fetches is None: - output_desc = self.model_desc_.outputs_ - else: - output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch] - - if not isinstance(input, (list, tuple)): - input = (input,) - - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - run_options.training_mode = False - - session_run_results = ort_training_session_run_helper( - self.session, self.eval_io_binding, input, input_desc, output_desc, self.device_, run_options - ) - - if len(session_run_results) == 1: - return session_run_results[next(iter(session_run_results.keys()))] - else: - return [session_run_results[output_desc.name_] for output_desc in output_desc] - - def _verify_fully_optimized_model(self, model): - assert len(model.graph.output) > 0 - # model's first output must be the loss tensor - if model.graph.output[0].type.tensor_type.elem_type not in { - onnx.TensorProto.FLOAT, - onnx.TensorProto.FLOAT16, - onnx.TensorProto.DOUBLE, - onnx.TensorProto.COMPLEX64, - onnx.TensorProto.COMPLEX128, - onnx.TensorProto.BFLOAT16, - onnx.TensorProto.FLOAT8E4M3FN, - onnx.TensorProto.FLOAT8E4M3FNUZ, - onnx.TensorProto.FLOAT8E5M2, - onnx.TensorProto.FLOAT8E5M2FNUZ, - }: - raise RuntimeError( - "the first output of a model to run with fully optimized ORT backend must be float types." - ) - if len(model.graph.output[0].type.tensor_type.shape.dim) != 0: - raise RuntimeError( - "the first output of a model to run with fully optimized ORT backend assumed to be loss and must be a scalar." - ) - - -class LossScaler: - def __init__( - self, - loss_scale_input_name, - is_dynamic_scale, - loss_scale=float(1 << 16), - up_scale_window=2000, - min_loss_scale=1.0, - max_loss_scale=float(1 << 24), - ): - super().__init__() - self.loss_scale_input_name_ = loss_scale_input_name - self.is_dynamic_scale_ = is_dynamic_scale - self.initial_loss_scale_ = loss_scale - self.up_scale_window_ = up_scale_window - self.min_loss_scale_ = min_loss_scale - self.max_loss_scale_ = max_loss_scale - self.loss_scale_ = loss_scale - self.stable_steps_ = 0 - - def update_loss_scale(self, is_all_finite): - if not self.is_dynamic_scale_: - return - - if is_all_finite: - self.stable_steps_ += 1 - - if self.stable_steps_ >= self.up_scale_window_: - self.loss_scale_ = min(self.max_loss_scale_, self.loss_scale_ * 2) - self.stable_steps_ = 0 - else: - self.loss_scale_ = max(self.min_loss_scale_, self.loss_scale_ / 2) - self.stable_steps_ = 0 - - def reset(self): - self.loss_scale_ = self.initial_loss_scale_ - self.stable_steps_ = 0 diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index a08e8bee99cee..bb1cb4bbd32f7 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -18,7 +18,6 @@ #include "core/session/environment.h" #include "core/session/custom_ops.h" #include "core/dlpack/dlpack_converter.h" -#include "orttraining/core/session/training_session.h" #include "orttraining/core/agent/training_agent.h" #include "orttraining/core/graph/gradient_config.h" #include "orttraining/core/graph/optimizer_config.h" @@ -113,14 +112,11 @@ struct TrainingParameters { std::unordered_set weights_to_train; std::unordered_set weights_not_to_train; - onnxruntime::training::TrainingSession::ImmutableWeights immutable_weights; - // optimizer std::string training_optimizer_name; std::string lr_params_feed_name = "Learning_Rate"; std::unordered_map> optimizer_attributes_map; std::unordered_map> optimizer_int_attributes_map; - onnxruntime::training::TrainingSession::OptimizerState optimizer_initial_state; std::unordered_map> sliced_schema; std::unordered_map sliced_axes; std::vector sliced_tensor_names; @@ -206,185 +202,6 @@ struct PyGradientGraphBuilderContext { local_registries_(local_registries) {} }; -// TODO: this method does not handle parallel optimization. -TrainingConfigurationResult ConfigureSessionForTraining( - training::PipelineTrainingSession* sess, TrainingParameters& parameters) { - // TODO tix, refactor the mpi related code to populate all fields correctly by default. - ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size, "data_parallel_size: ", parameters.data_parallel_size, ", world_size: ", parameters.world_size); - ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size, "horizontal_parallel_size: ", parameters.horizontal_parallel_size, ", world_size: ", parameters.world_size); - ORT_ENFORCE(parameters.pipeline_parallel_size <= parameters.world_size, "pipeline_parallel_size: ", parameters.pipeline_parallel_size, ", world_size: ", parameters.world_size); - - // When DxHxP != the total number of ranks, we try adjusting D so that DxHxP == the total number of ranks. - if (parameters.world_size != parameters.data_parallel_size * parameters.horizontal_parallel_size * parameters.pipeline_parallel_size) { - ORT_ENFORCE(parameters.world_size % parameters.horizontal_parallel_size * parameters.pipeline_parallel_size == 0, - "D, H, P sizes are incorrect. To enable automatic correction, total number of ranks must be a divisible by HxP."); - - const auto new_data_parallel_size = parameters.world_size / (parameters.horizontal_parallel_size * parameters.pipeline_parallel_size); - parameters.data_parallel_size = new_data_parallel_size; - - const std::string msg = "Cannot distribute " + std::to_string(parameters.world_size) + " ranks for distributed computation with D=" + std::to_string(parameters.data_parallel_size) + - ", H=" + std::to_string(parameters.horizontal_parallel_size) + ", P=" + std::to_string(parameters.pipeline_parallel_size) + ", so D is automatically changed to " + std::to_string(new_data_parallel_size); - LOGS(*(sess->GetLogger()), WARNING) << msg; - } - - training::PipelineTrainingSession::TrainingConfiguration config{}; - config.weight_names_to_train = parameters.weights_to_train; - config.weight_names_to_not_train = parameters.weights_not_to_train; - config.immutable_weights = parameters.immutable_weights; - config.gradient_accumulation_steps = parameters.gradient_accumulation_steps; - - config.distributed_config.world_rank = parameters.world_rank; - config.distributed_config.world_size = parameters.world_size; - config.distributed_config.local_rank = parameters.local_rank; - config.distributed_config.local_size = parameters.local_size; - config.distributed_config.data_parallel_size = parameters.data_parallel_size; - config.distributed_config.horizontal_parallel_size = parameters.horizontal_parallel_size; - config.distributed_config.pipeline_parallel_size = parameters.pipeline_parallel_size; - config.distributed_config.num_pipeline_micro_batches = parameters.num_pipeline_micro_batches; - config.distributed_config.sliced_schema = parameters.sliced_schema; - config.distributed_config.sliced_axes = parameters.sliced_axes; - config.distributed_config.sliced_tensor_names = parameters.sliced_tensor_names; - - if (parameters.use_mixed_precision) { - training::PipelineTrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mp{}; - mp.use_mixed_precision_initializers = true; - - config.mixed_precision_config = mp; - } - - if (config.distributed_config.pipeline_parallel_size > 1) { - training::PipelineTrainingSession::TrainingConfiguration::PipelineConfiguration pipeline_config; - - // Currently don't support auto-partition. User needs to pass in cut information for pipeline - pipeline_config.do_partition = true; - assert(!parameters.pipeline_cut_info_string.empty()); - - auto process_with_delimiter = [](std::string& input_str, const std::string& delimiter) { - std::vector result; - size_t pos = 0; - while ((pos = input_str.find(delimiter)) != std::string::npos) { - std::string token = input_str.substr(0, pos); - result.emplace_back(token); - input_str.erase(0, pos + delimiter.length()); - } - // push the last split of substring into result. - result.emplace_back(input_str); - return result; - }; - - auto process_cut_info = [&](std::string& cut_info_string) { - std::vector cut_list; - const std::string group_delimiter = ","; - const std::string edge_delimiter = ":"; - const std::string consumer_delimiter = "/"; - const std::string producer_consumer_delimiter = "-"; - - auto cut_info_groups = process_with_delimiter(cut_info_string, group_delimiter); - for (auto& cut_info_group : cut_info_groups) { - PipelineTrainingSession::TrainingConfiguration::CutInfo cut_info; - auto cut_edges = process_with_delimiter(cut_info_group, edge_delimiter); - for (auto& cut_edge : cut_edges) { - auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter); - if (process_edge.size() == 1) { - PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]}; - cut_info.emplace_back(edge); - } else { - ORT_ENFORCE(process_edge.size() == 2); - auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter); - - PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0], consumer_list}; - cut_info.emplace_back(edge); - } - } - cut_list.emplace_back(cut_info); - } - return cut_list; - }; - - pipeline_config.cut_list = process_cut_info(parameters.pipeline_cut_info_string); - config.pipeline_config = pipeline_config; - } - config.loss_name = parameters.loss_output_name; - - if (!parameters.training_optimizer_name.empty()) { - training::PipelineTrainingSession::TrainingConfiguration::OptimizerConfiguration opt{}; - opt.name = parameters.training_optimizer_name; - opt.learning_rate_input_name = parameters.lr_params_feed_name; - opt.weight_attributes_generator = [¶meters](const std::string& weight_name) { - const auto it = parameters.optimizer_attributes_map.find(weight_name); - ORT_ENFORCE( - it != parameters.optimizer_attributes_map.end(), - "Failed to find attribute map for weight ", weight_name); - return it->second; - }; - opt.weight_int_attributes_generator = [¶meters](const std::string& weight_name) { - const auto it = parameters.optimizer_int_attributes_map.find(weight_name); - ORT_ENFORCE( - it != parameters.optimizer_int_attributes_map.end(), - "Failed to find int attribute map for weight ", weight_name); - return it->second; - }; - opt.use_mixed_precision_moments = parameters.use_fp16_moments; - opt.do_all_reduce_in_mixed_precision_type = true; - // TODO: this mapping is temporary. - // For now, nccl allreduce kernel only implements for allreduce_post_accumulation - // hovorod allreduce kernel only implements for not allreduce_post_accumulation. - // eventually we will have one all reduce kernel and let opt to have - // an allreduce_post_accumulation option and remove the use_nccl option. - opt.use_nccl = parameters.allreduce_post_accumulation; - opt.deepspeed_zero = onnxruntime::training::ZeROConfig(parameters.deepspeed_zero_stage); - opt.enable_grad_norm_clip = parameters.enable_grad_norm_clip; - - // TODO reduction types - if (parameters.enable_adasum) { -#ifdef USE_CUDA - opt.adasum_reduction_type = training::AdasumReductionType::GpuHierarchicalReduction; -#else - opt.adasum_reduction_type = training::AdasumReductionType::CpuReduction; -#endif - } - - config.optimizer_config = opt; - } - - if (!parameters.optimizer_initial_state.empty()) { - config.init_optimizer_states = parameters.optimizer_initial_state; - } - - config.gradient_graph_config.use_memory_efficient_gradient = parameters.use_memory_efficient_gradient; - config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs; - - config.graph_transformer_config.attn_dropout_recompute = parameters.attn_dropout_recompute; - config.graph_transformer_config.gelu_recompute = parameters.gelu_recompute; - config.graph_transformer_config.transformer_layer_recompute = parameters.transformer_layer_recompute; - config.graph_transformer_config.number_recompute_layers = parameters.number_recompute_layers; - config.graph_transformer_config.propagate_cast_ops_config.strategy = parameters.propagate_cast_ops_strategy; - config.graph_transformer_config.propagate_cast_ops_config.level = parameters.propagate_cast_ops_level; - config.graph_transformer_config.propagate_cast_ops_config.allow = parameters.propagate_cast_ops_allow; - - if (!parameters.model_after_graph_transforms_path.empty()) { - config.model_after_graph_transforms_path = ToPathString(parameters.model_after_graph_transforms_path); - } - if (!parameters.model_with_gradient_graph_path.empty()) { - config.model_with_gradient_graph_path = ToPathString(parameters.model_with_gradient_graph_path); - } - if (!parameters.model_with_training_graph_path.empty()) { - config.model_with_training_graph_path = ToPathString(parameters.model_with_training_graph_path); - } - - training::PipelineTrainingSession::TrainingConfigurationResult config_result{}; - - OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result)); - - TrainingConfigurationResult python_config_result{}; - if (config_result.mixed_precision_config_result.has_value()) { - const auto& mp_config_result = config_result.mixed_precision_config_result.value(); - python_config_result.loss_scale_input_name = mp_config_result.loss_scale_input_name; - } - - return python_config_result; -} - #if defined(USE_MPI) void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) { LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank(); @@ -424,7 +241,7 @@ std::unordered_map> Con return py_tensor_state; } -void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) { +void addObjectMethodsForTraining(py::module& m) { py::class_(m, "OrtValueCache") .def(py::init<>()) .def("insert", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name, OrtValue& value) { @@ -451,7 +268,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn py::class_ parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc"); parameters.def(py::init()) .def_readwrite("loss_output_name", &TrainingParameters::loss_output_name) - .def_readwrite("immutable_weights", &TrainingParameters::immutable_weights) .def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train) .def_readwrite("weights_to_train", &TrainingParameters::weights_to_train) .def_readwrite("sliced_tensor_names", &TrainingParameters::sliced_tensor_names) @@ -484,25 +300,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn .def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size) .def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size) .def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size) - .def("set_optimizer_initial_state", - [](TrainingParameters& parameters, const std::unordered_map>& py_state) -> void { - onnxruntime::training::TrainingSession::OptimizerState optim_state; - for (const auto& weight_it : py_state) { - auto state = weight_it.second; - NameMLValMap state_tensors; - for (auto& initializer : state) { - OrtValue ml_value; - - // InputDeflist is null because parameters havent been tied to session yet - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - CreateGenericMLValue(nullptr, GetAllocator(), "", initializer.second, &ml_value, true); - ThrowIfPyErrOccured(); - state_tensors.emplace(initializer.first, ml_value); - } - optim_state.emplace(weight_it.first, state_tensors); - } - parameters.optimizer_initial_state = optim_state; - }) .def_readwrite("model_after_graph_transforms_path", &TrainingParameters::model_after_graph_transforms_path) .def_readwrite("model_with_gradient_graph_path", &TrainingParameters::model_with_gradient_graph_path) .def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path) @@ -611,130 +408,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn }); #endif - py::class_ config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc"); - config_result.def(py::init()) - .def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object { - if (result.loss_scale_input_name.has_value()) { - return py::str{result.loss_scale_input_name.value()}; - } - return py::none(); - }); - - // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user - struct PyTrainingSession : public PyInferenceSession { - PyTrainingSession(std::shared_ptr env, const PySessionOptions& so) - : PyInferenceSession(env, std::make_unique(so.value, *env)) { - } - ~PyTrainingSession() = default; - }; - - py::class_ training_session(m, "TrainingSession"); - training_session - .def(py::init([](const PySessionOptions& so) { - auto& training_env = GetTrainingEnv(); - return std::make_unique(training_env.GetORTEnv(), so); - })) - .def(py::init([]() { - auto& training_env = GetTrainingEnv(); - return std::make_unique(training_env.GetORTEnv(), GetDefaultCPUSessionOptions()); - })) - .def("finalize", [](py::object) { -#if defined(USE_MPI) -#ifdef _WIN32 - // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices - // shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction - // call shutdown_mpi() here instead. - MPIContext::shutdown_mpi(); -#endif -#endif - }) - .def("load_model", [ep_registration_fn](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector& provider_types, const ProviderOptionsVector& provider_options) { - OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path)); - -#if defined(USE_MPI) - bool use_nccl = parameters.allreduce_post_accumulation; - if (!use_nccl && parameters.world_size > 1) - CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger()); -#endif - const auto config_result = ConfigureSessionForTraining(static_cast(sess->GetSessionHandle()), parameters); - - ProviderOptionsVector merged_options; - ResolveExtraProviderOptions(provider_types, provider_options, merged_options); - - InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options); - - return config_result; - }) - .def("read_bytes", [ep_registration_fn](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector& provider_types, const ProviderOptionsVector& provider_options) { - std::istringstream buffer(serialized_model); - OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer)); - -#if defined(USE_MPI) - bool use_nccl = parameters.allreduce_post_accumulation; - if (!use_nccl && parameters.world_size > 1) - CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger()); -#endif - const auto config_result = ConfigureSessionForTraining(static_cast(sess->GetSessionHandle()), parameters); - ProviderOptionsVector merged_options; - ResolveExtraProviderOptions(provider_types, provider_options, merged_options); - - InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options); - - return config_result; - }) - .def("get_state", [](PyTrainingSession* sess) { - NameMLValMap state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetStateTensors(state_tensors)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - // convert to numpy array - std::map rmap; - for (auto& kv : state_tensors) { - if (kv.second.IsTensor()) { - py::object obj; - const Tensor& rtensor = kv.second.Get(); - GetPyObjFromTensor(rtensor, obj, &data_transfer_manager); - rmap.insert({kv.first, obj}); - } else { - throw std::runtime_error("Non tensor type in session state tensors is not expected."); - } - } - return rmap; - }) - .def("get_model_state", [](PyTrainingSession* sess, bool include_mixed_precision_weights) { - std::unordered_map model_state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetModelState(model_state_tensors, include_mixed_precision_weights)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - return ConvertORTTensorMapToNumpy(model_state_tensors, data_transfer_manager); - }) - .def("get_optimizer_state", [](PyTrainingSession* sess) { - std::unordered_map opt_state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetOptimizerState(opt_state_tensors)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - return ConvertORTTensorMapToNumpy(opt_state_tensors, data_transfer_manager); - }) - .def("get_partition_info_map", [](PyTrainingSession* sess) { - std::unordered_map>> part_info_map; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetPartitionInfoMap(part_info_map)); - return part_info_map; - }) - .def("load_state", [](PyTrainingSession* sess, std::unordered_map& state, bool strict) { - NameMLValMap state_tensors; - for (auto initializer : state) { - OrtValue ml_value; - auto px = sess->GetSessionHandle()->GetModelInputs(); - if (!px.first.IsOK() || !px.second) { - throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null"); - } - CreateGenericMLValue(px.second, GetAllocator(), initializer.first, initializer.second, &ml_value); - ThrowIfPyErrOccured(); - state_tensors.insert(std::make_pair(initializer.first, ml_value)); - } - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->SetStateTensors(state_tensors, strict)); - }) - .def("is_output_fp32_node", [](PyTrainingSession* sess, const std::string& output_name) { - return static_cast(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name); - }); - py::class_(m, "PartialGraphExecutionState") .def(py::init([]() { return std::make_unique(); diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 88ef90a7feaa8..4d1db7334f280 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -40,7 +40,7 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* void addGlobalMethods(py::module& m); void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn); -void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn); +void addObjectMethodsForTraining(py::module& m); void addObjectMethodsForEager(py::module& m); #ifdef ENABLE_LAZY_TENSOR void addObjectMethodsForLazyTensor(py::module& m); @@ -339,7 +339,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { } #endif - addObjectMethodsForTraining(m, ORTTrainingRegisterExecutionProviders); + addObjectMethodsForTraining(m); #ifdef ENABLE_LAZY_TENSOR addObjectMethodsForLazyTensor(m); diff --git a/orttraining/orttraining/python/training/__init__.py b/orttraining/orttraining/python/training/__init__.py index 73b1f826f68e1..a3c22686a1039 100644 --- a/orttraining/orttraining/python/training/__init__.py +++ b/orttraining/orttraining/python/training/__init__.py @@ -8,26 +8,16 @@ TrainingParameters, is_ortmodule_available, ) -from onnxruntime.capi.training.training_session import TrainingSession - # Options need to be imported before `ORTTrainer`. -from .orttrainer_options import ORTTrainerOptions -from .orttrainer import ORTTrainer, TrainStepInfo -from . import amp, artifacts, checkpoint, model_desc_validation, optim +from . import amp, artifacts, optim __all__ = [ "PropagateCastOpsStrategy", "TrainingParameters", "is_ortmodule_available", - "TrainingSession", - "ORTTrainerOptions", - "ORTTrainer", - "TrainStepInfo", "amp", "artifacts", - "checkpoint", - "model_desc_validation", "optim", ] diff --git a/orttraining/orttraining/python/training/_checkpoint_storage.py b/orttraining/orttraining/python/training/_checkpoint_storage.py deleted file mode 100644 index 7a8ada7dee96b..0000000000000 --- a/orttraining/orttraining/python/training/_checkpoint_storage.py +++ /dev/null @@ -1,107 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import pickle -from collections.abc import Mapping - -import h5py - - -def _dfs_save(group, save_obj): - """Recursively go over each level in the save_obj dictionary and save values to a hdf5 group""" - - for key, value in save_obj.items(): - if isinstance(value, Mapping): - subgroup = group.create_group(key) - _dfs_save(subgroup, value) - else: - group[key] = value - - -def save(save_obj: dict, path): - """Persists the input dictionary to a file specified by path. - - Saves an hdf5 representation of the save_obj dictionary to a file or a file-like object specified by path. - Values are saved in a format supported by h5py. For example, a PyTorch tensor is saved and loaded as a - numpy object. So, user types may be converted from their original types to numpy equivalent types. - - Args: - save_obj: dictionary that needs to be saved. - save_obj should consist of types supported by hdf5 file format. - if hdf5 does not recognize a type, an exception is raised. - if save_obj is not a dictionary, a ValueError is raised. - path: string representation to a file path or a python file-like object. - if file already exists at path, an exception is raised. - """ - if not isinstance(save_obj, Mapping): - raise ValueError("Object to be saved must be a dictionary") - - with h5py.File(path, "w-") as f: - _dfs_save(f, save_obj) - - -def _dfs_load(group, load_obj): - """Recursively go over each level in the hdf5 group and load the values into the given dictionary""" - - for key in group: - if isinstance(group[key], h5py.Group): - load_obj[key] = {} - _dfs_load(group[key], load_obj[key]) - else: - load_obj[key] = group[key][()] - - -def load(path, key=None): - """Loads the data stored in the binary file specified at the given path into a dictionary and returns it. - - Loads the data from an hdf5 file specified at the given path into a python dictionary. - Loaded dictionary contains numpy equivalents of python data types. For example: - PyTorch tensor -> saved as a numpy array and loaded as a numpy array. - bool -> saved as a numpy bool and loaded as a numpy bool - If a '/' separated key is provided, the value at that hierarchical level in the hdf5 group is returned. - - Args: - path: string representation to a file path or a python file-like object. - if file does not already exist at path, an exception is raised. - key: '/' separated representation of the hierarchy level value that needs to be returned/ - for example, if the saved binary file has structure {a: {b: x, c:y}} and the user would like - to query the value for c, the key provided should be 'a/c'. - the default value of None for key implies that the entire hdf5 file structure needs to be loaded into a dictionary and returned. - - Returns: - a dictionary loaded from the specified binary hdf5 file. - """ - if not h5py.is_hdf5(path): - raise ValueError(f"{path} is not an hdf5 file or a python file-like object.") - - load_obj = {} - with h5py.File(path, "r") as f: - if key: - f = f[key] # noqa: PLW2901 - if isinstance(f, h5py.Dataset): - return f[()] - - _dfs_load(f, load_obj) - - return load_obj - - -def to_serialized_hex(user_dict): - """Serialize the user_dict and convert the serialized bytes to a hex string and return""" - - return pickle.dumps(user_dict).hex() - - -def from_serialized_hex(serialized_hex): - """Convert serialized_hex to bytes and deserialize it and return""" - - # serialized_hex can be either a regular string or a byte string. - # if it is a byte string, convert to regular string using decode() - # if it is a regular string, do nothing to it - try: # noqa: SIM105 - serialized_hex = serialized_hex.decode() - except AttributeError: - pass - return pickle.loads(bytes.fromhex(serialized_hex)) diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py index 4eb79443c8f1a..091274d1d171d 100644 --- a/orttraining/orttraining/python/training/_utils.py +++ b/orttraining/orttraining/python/training/_utils.py @@ -6,11 +6,9 @@ import importlib.util import os import sys -from functools import wraps # noqa: F401 import numpy as np import torch -from onnx import TensorProto # noqa: F401 from packaging.version import Version @@ -23,16 +21,6 @@ def get_device_index(device): return 0 if device.index is None else device.index -def get_device_index_from_input(input): - """Returns device index from a input PyTorch Tensor""" - - if isinstance(input, (list, tuple)): - device_index = get_device_index(input[0].device) - else: - device_index = get_device_index(input.device) - return device_index - - def get_device_str(device): if isinstance(device, str): # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 @@ -50,24 +38,6 @@ def get_device_str(device): return device -def get_all_gradients_finite_name_from_session(session): - """Find all_gradients_finite node on Session graph and return its name""" - - nodes = [x for x in session._outputs_meta if "all_gradients_finite" in x.name] - if len(nodes) != 1: - raise RuntimeError("'all_gradients_finite' node not found within training session") - return nodes[0].name - - -def get_gradient_accumulation_name_from_session(session): - """Find Group_Accumulated_Gradients node on Session graph and return its name""" - - nodes = [x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name] - if len(nodes) != 1: - raise RuntimeError("'Group_Accumulated_Gradients' node not found within training session") - return nodes[0].name - - def dtype_torch_to_numpy(torch_dtype): """Converts PyTorch types to Numpy types @@ -232,111 +202,3 @@ def import_module_from_file(file_path, module_name=None): sys.modules[module_name] = module spec.loader.exec_module(module) return module - - -def state_dict_model_key(): - """Returns the model key name in the state dictionary""" - - return "model" - - -def state_dict_optimizer_key(): - """Returns the optimizer key name in the state dictionary""" - - return "optimizer" - - -def state_dict_partition_info_key(): - """Returns the partition info key name in the state dictionary""" - - return "partition_info" - - -def state_dict_trainer_options_key(): - """Returns the trainer options key name in the state dictionary""" - - return "trainer_options" - - -def state_dict_full_precision_key(): - """Returns the full precision key name in the state dictionary""" - - return "full_precision" - - -def state_dict_original_dimension_key(): - """Returns the original dimension key name in the state dictionary""" - - return "original_dim" - - -def state_dict_sharded_optimizer_keys(): - """Returns the optimizer key names that can be sharded in the state dictionary""" - - return {"Moment_1", "Moment_2"} - - -def state_dict_user_dict_key(): - """Returns the user dict key name in the state dictionary""" - - return "user_dict" - - -def state_dict_trainer_options_mixed_precision_key(): - """Returns the trainer options mixed precision key name in the state dictionary""" - - return "mixed_precision" - - -def state_dict_trainer_options_zero_stage_key(): - """Returns the trainer options zero_stage key name in the state dictionary""" - - return "zero_stage" - - -def state_dict_trainer_options_world_rank_key(): - """Returns the trainer options world_rank key name in the state dictionary""" - - return "world_rank" - - -def state_dict_trainer_options_world_size_key(): - """Returns the trainer options world_size key name in the state dictionary""" - - return "world_size" - - -def state_dict_trainer_options_data_parallel_size_key(): - """Returns the trainer options data_parallel_size key name in the state dictionary""" - - return "data_parallel_size" - - -def state_dict_trainer_options_horizontal_parallel_size_key(): - """Returns the trainer options horizontal_parallel_size key name in the state dictionary""" - - return "horizontal_parallel_size" - - -def state_dict_trainer_options_optimizer_name_key(): - """Returns the trainer options optimizer_name key name in the state dictionary""" - - return "optimizer_name" - - -def state_dict_train_step_info_key(): - """Returns the train step info key name in the state dictionary""" - - return "train_step_info" - - -def state_dict_train_step_info_optimization_step_key(): - """Returns the train step info optimization step key name in the state dictionary""" - - return "optimization_step" - - -def state_dict_train_step_info_step_key(): - """Returns the train step info step key name in the state dictionary""" - - return "step" diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 549614de496a6..a57105545e114 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -53,6 +53,8 @@ def generate_artifacts( 3. Checkpoint (directory): Contains the model parameters. 4. Optimizer model (onnx.ModelProto): Model containing the optimizer graph. + All generated ModelProtos will use the same opsets defined by *model*. + Args: model: The base model to be used for gradient graph generation. requires_grad: List of names of model parameters that require gradient computation @@ -207,11 +209,17 @@ def _export_to_ort_format(model_path, output_dir, extra_options): logging.info("Optimizer enum provided: %s", optimizer.name) + opset_version = None + for domain in model.opset_import: + if domain.domain == "" or domain.domain == "ai.onnx": + opset_version = domain.version + break + optim_model = None optim_blocks = {OptimType.AdamW: onnxblock.optim.AdamW, OptimType.SGD: onnxblock.optim.SGD} optim_block = optim_blocks[optimizer]() - with onnxblock.empty_base(): + with onnxblock.empty_base(opset_version=opset_version): _ = optim_block(model_params) optim_model = optim_block.to_model_proto() diff --git a/orttraining/orttraining/python/training/checkpoint.py b/orttraining/orttraining/python/training/checkpoint.py deleted file mode 100644 index d0ff0650662b7..0000000000000 --- a/orttraining/orttraining/python/training/checkpoint.py +++ /dev/null @@ -1,748 +0,0 @@ -import os -import tempfile -import warnings -from enum import Enum - -import numpy as np -import onnx -import torch - -from . import _checkpoint_storage, _utils - -################################################################################ -# Experimental Checkpoint APIs -################################################################################ - - -def experimental_state_dict(ort_trainer, include_optimizer_state=True): - warnings.warn( - "experimental_state_dict() will be deprecated soon. Please use ORTTrainer.state_dict() instead.", - DeprecationWarning, - ) - - if not ort_trainer._training_session: - warnings.warn( - "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling state_dict()." - ) - return ort_trainer._state_dict - - # extract trained weights - session_state = ort_trainer._training_session.get_state() - torch_state = {} - for name in session_state: - torch_state[name] = torch.from_numpy(session_state[name]) - - # extract untrained weights and buffer - for n in ort_trainer._onnx_model.graph.initializer: - if n.name not in torch_state and n.name in ort_trainer.options.utils.frozen_weights: - torch_state[n.name] = torch.from_numpy(np.array(onnx.numpy_helper.to_array(n))) - - # Need to remove redundant (optimizer) initializers to map back to original torch state names - if not include_optimizer_state and ort_trainer._torch_state_dict_keys: - return {key: torch_state[key] for key in ort_trainer._torch_state_dict_keys if key in torch_state} - return torch_state - - -def experimental_load_state_dict(ort_trainer, state_dict, strict=False): - warnings.warn( - "experimental_load_state_dict() will be deprecated soon. Please use ORTTrainer.load_state_dict() instead.", - DeprecationWarning, - ) - - # Note: It may happen ONNX model has not yet been initialized - # In this case we cache a reference to desired state and delay the restore until after initialization - # Unexpected behavior will result if the user changes the reference before initialization - if not ort_trainer._training_session: - ort_trainer._state_dict = state_dict - ort_trainer._load_state_dict_strict = strict - return - - # Update onnx model from loaded state dict - cur_initializers_names = [n.name for n in ort_trainer._onnx_model.graph.initializer] - new_initializers = {} - - for name in state_dict: - if name in cur_initializers_names: - new_initializers[name] = state_dict[name].numpy() - elif strict: - raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.") - - ort_trainer._update_onnx_model_initializers(new_initializers) - - # create new session based on updated onnx model - ort_trainer._state_dict = None - ort_trainer._init_session() - - # load training state - session_state = {name: state_dict[name].numpy() for name in state_dict} - ort_trainer._training_session.load_state(session_state, strict) - - -def experimental_save_checkpoint( - ort_trainer, - checkpoint_dir, - checkpoint_prefix="ORT_checkpoint", - checkpoint_state_dict=None, - include_optimizer_state=True, -): - warnings.warn( - "experimental_save_checkpoint() will be deprecated soon. Please use ORTTrainer.save_checkpoint() instead.", - DeprecationWarning, - ) - - if checkpoint_state_dict is None: - checkpoint_state_dict = {"model": experimental_state_dict(ort_trainer, include_optimizer_state)} - else: - checkpoint_state_dict.update({"model": experimental_state_dict(ort_trainer, include_optimizer_state)}) - - assert os.path.exists(checkpoint_dir), f"checkpoint_dir ({checkpoint_dir}) directory doesn't exist" - - checkpoint_name = _get_checkpoint_name( - checkpoint_prefix, - ort_trainer.options.distributed.deepspeed_zero_optimization.stage, - ort_trainer.options.distributed.world_rank, - ort_trainer.options.distributed.world_size, - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - if os.path.exists(checkpoint_file): - msg = f"{checkpoint_file} already exists, overwriting." - warnings.warn(msg) - torch.save(checkpoint_state_dict, checkpoint_file) - - -def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): - warnings.warn( - "experimental_load_checkpoint() will be deprecated soon. Please use ORTTrainer.load_checkpoint() instead.", - DeprecationWarning, - ) - - checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - is_partitioned = False - if len(checkpoint_files) > 1: - msg = ( - f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." - " Attempting to load ZeRO checkpoint." - ) - warnings.warn(msg) - is_partitioned = True - if (not ort_trainer.options.distributed.deepspeed_zero_optimization.stage) and is_partitioned: - return _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict) - else: - return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) - - -class _AGGREGATION_MODE(Enum): # noqa: N801 - Zero = 0 - Megatron = 1 - - -def _order_paths(paths, D_groups, H_groups): - """Reorders the given paths in order of aggregation of ranks for D and H parallellism respectively - and returns the ordered dict""" - - trainer_options_path_tuples = [] - world_rank = _utils.state_dict_trainer_options_world_rank_key() - - for path in paths: - trainer_options_path_tuples.append( - (_checkpoint_storage.load(path, key=_utils.state_dict_trainer_options_key()), path) - ) - - # sort paths according to rank - sorted_paths = [ - path - for _, path in sorted( - trainer_options_path_tuples, key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank] - ) - ] - - ordered_paths = dict() - ordered_paths["D"] = [[sorted_paths[i] for i in D_groups[group_id]] for group_id in range(len(D_groups))] - ordered_paths["H"] = [[sorted_paths[i] for i in H_groups[group_id]] for group_id in range(len(H_groups))] - - return ordered_paths - - -def _add_or_update_sharded_key( - state_key, state_value, state_sub_dict, model_state_key, state_partition_info, sharded_states_original_dims, mode -): - """Add or update the record for the sharded state_key in the state_sub_dict""" - - # record the original dimension for this state - original_dim = _utils.state_dict_original_dimension_key() - sharded_states_original_dims[model_state_key] = state_partition_info[original_dim] - - axis = 0 - if mode == _AGGREGATION_MODE.Megatron and state_partition_info["megatron_row_partition"] == 0: - axis = -1 - - if state_key in state_sub_dict: - # state_dict already contains a record for this state - # since this state is sharded, concatenate the state value to - # the record in the state_dict - state_sub_dict[state_key] = np.concatenate((state_sub_dict[state_key], state_value), axis) - else: - # create a new entry for this state in the state_dict - state_sub_dict[state_key] = state_value - - -def _add_or_validate_unsharded_key(state_key, state_value, state_sub_dict, mismatch_error_string): - """Add or validate the record for the unsharded state_key in the state_sub_dict""" - - if state_key in state_sub_dict: - # state_dict already contains a record for this unsharded state. - # assert that all values are the same for this previously loaded state - assert (state_sub_dict[state_key] == state_value).all(), mismatch_error_string - else: - # create a new entry for this state in the state_sub_dict - state_sub_dict[state_key] = state_value - - -def _aggregate_model_states( - rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled, mode=_AGGREGATION_MODE.Zero -): - """Aggregates all model states from the rank_state_dict into state_dict""" - - model = _utils.state_dict_model_key() - full_precision = _utils.state_dict_full_precision_key() - partition_info = _utils.state_dict_partition_info_key() - - # if there are no model states in the rank_state_dict, no model aggregation is needed - if model not in rank_state_dict: - return - - if model not in state_dict: - state_dict[model] = {} - - if full_precision not in state_dict[model]: - state_dict[model][full_precision] = {} - - # iterate over all model state keys - for model_state_key, model_state_value in rank_state_dict[model][full_precision].items(): - # ZERO: full precision model states are sharded only when they exist in the partition_info subdict and mixed - # precision training was enabled. for full precision training, full precision model states are not sharded - # MEGATRON : full precision model states are sharded when they exist in the partition_info subdict - if (model_state_key in rank_state_dict[partition_info]) and ( - mode == _AGGREGATION_MODE.Megatron or mixed_precision_enabled - ): - # this model state is sharded - _add_or_update_sharded_key( - model_state_key, - model_state_value, - state_dict[model][full_precision], - model_state_key, - rank_state_dict[partition_info][model_state_key], - sharded_states_original_dims, - mode, - ) - else: - # this model state is not sharded since a record for it does not exist in the partition_info subdict - _add_or_validate_unsharded_key( - model_state_key, - model_state_value, - state_dict[model][full_precision], - f"Value mismatch for model state {model_state_key}", - ) - - -def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode=_AGGREGATION_MODE.Zero): - """Aggregates all optimizer states from the rank_state_dict into state_dict""" - - optimizer = _utils.state_dict_optimizer_key() - partition_info = _utils.state_dict_partition_info_key() - sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() - - # if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed - if optimizer not in rank_state_dict: - return - - if optimizer not in state_dict: - state_dict[optimizer] = {} - - # iterate over all optimizer state keys - for model_state_key, optimizer_dict in rank_state_dict[optimizer].items(): - for optimizer_key, optimizer_value in optimizer_dict.items(): - if model_state_key not in state_dict[optimizer]: - state_dict[optimizer][model_state_key] = {} - - if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]: - # this optimizer state is sharded since a record exists in the partition_info subdict - _add_or_update_sharded_key( - optimizer_key, - optimizer_value, - state_dict[optimizer][model_state_key], - model_state_key, - rank_state_dict[partition_info][model_state_key], - sharded_states_original_dims, - mode, - ) - else: - # this optimizer state is not sharded since a record for it does not exist in the partition_info subdict - # or this optimizer key is not one of the sharded optimizer keys - _add_or_validate_unsharded_key( - optimizer_key, - optimizer_value, - state_dict[optimizer][model_state_key], - f"Value mismatch for model state {model_state_key} and optimizer state {optimizer_key}", - ) - - -def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled): - """Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims""" - - model = _utils.state_dict_model_key() - full_precision = _utils.state_dict_full_precision_key() - optimizer = _utils.state_dict_optimizer_key() - sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() - - for sharded_state_key, original_dim in sharded_states_original_dims.items(): - # reshape model states to original_dim only when mixed precision is enabled - if mixed_precision_enabled and (model in state_dict): - state_dict[model][full_precision][sharded_state_key] = state_dict[model][full_precision][ - sharded_state_key - ].reshape(original_dim) - - # reshape optimizer states to original_dim - if optimizer in state_dict: - for optimizer_key, optimizer_value in state_dict[optimizer][sharded_state_key].items(): - if optimizer_key in sharded_optimizer_keys: - state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim) - - -def _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation): - """Extracts trainer options from rank_state_dict and loads them accordingly on state_dict""" - trainer_options = _utils.state_dict_trainer_options_key() - state_dict[trainer_options] = {} - - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - - state_dict[trainer_options][mixed_precision] = rank_state_dict[trainer_options][mixed_precision] - state_dict[trainer_options][zero_stage] = 0 - state_dict[trainer_options][world_rank] = rank_state_dict[trainer_options][world_rank] if partial_aggregation else 0 - state_dict[trainer_options][world_size] = 1 - state_dict[trainer_options][optimizer_name] = rank_state_dict[trainer_options][optimizer_name] - state_dict[trainer_options][D_size] = 1 - state_dict[trainer_options][H_size] = 1 - - -def _aggregate_megatron_partition_info(rank_state_dict, state_dict): - """Extracts partition_info from rank_state_dict and loads on state_dict for megatron-partitioned weights""" - partition_info = _utils.state_dict_partition_info_key() - if partition_info not in state_dict: - state_dict[partition_info] = {} - - rank_partition_info = rank_state_dict[partition_info] - for model_state_key, partition_info_dict in rank_partition_info.items(): - if model_state_key not in state_dict[partition_info]: - # add partition info only if weight is megatron partitioned - if partition_info_dict["megatron_row_partition"] >= 0: - state_dict[partition_info][model_state_key] = partition_info_dict - - -def _to_pytorch_format(state_dict): - """Convert ORT state dictionary schema (hierarchical structure) to PyTorch state dictionary schema (flat structure)""" - - pytorch_state_dict = {} - for model_state_key, model_state_value in state_dict[_utils.state_dict_model_key()][ - _utils.state_dict_full_precision_key() - ].items(): - # convert numpy array to a torch tensor - pytorch_state_dict[model_state_key] = torch.tensor(model_state_value) - return pytorch_state_dict - - -def _get_parallellism_groups(data_parallel_size, horizontal_parallel_size, world_size): - """Returns the D and H groups for the given sizes""" - num_data_groups = world_size // data_parallel_size - data_groups = [] - for data_group_id in range(num_data_groups): - data_group_ranks = [] - for r in range(data_parallel_size): - data_group_ranks.append(data_group_id + horizontal_parallel_size * r) - data_groups.append(data_group_ranks) - - num_horizontal_groups = world_size // horizontal_parallel_size - horizontal_groups = [] - for hori_group_id in range(num_horizontal_groups): - hori_group_ranks = [] - for r in range(horizontal_parallel_size): - hori_group_ranks.append(hori_group_id * horizontal_parallel_size + r) - horizontal_groups.append(hori_group_ranks) - - return data_groups, horizontal_groups - - -def _aggregate_over_ranks( - ordered_paths, - ranks, - sharded_states_original_dims=None, - mode=_AGGREGATION_MODE.Zero, - partial_aggregation=False, - pytorch_format=True, -): - """Aggregate checkpoint files over set of ranks and return a single state dictionary - - Args: - ordered_paths: list of paths in the order in which they must be aggregated - ranks: list of ranks that are to be aggregated - sharded_states_original_dims: dict containing the original dims for sharded states that are persisted over - multiple calls to _aggregate_over_ranks() - mode: mode of aggregation: Zero or Megatron - partial_aggregation: boolean flag to indicate whether to produce a partially - aggregated state which can be further aggregated over - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict - Returns: - state_dict that can be loaded into an ORTTrainer or into a PyTorch model - """ - state_dict = {} - if sharded_states_original_dims is None: - sharded_states_original_dims = dict() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - - loaded_mixed_precision = None - loaded_world_size = None - loaded_zero_stage = None - loaded_optimizer_name = None - - for i, path in enumerate(ordered_paths): - rank_state_dict = _checkpoint_storage.load(path) - - assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info" - assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options" - assert ( - ranks[i] == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] - ), "Unexpected rank in file at path {}. Expected {}, got {}".format( - path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] # noqa: F821 - ) - if loaded_mixed_precision is None: - loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] - else: - assert ( - loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] - ), f"Mixed precision state mismatch among checkpoint files. File: {path}" - if loaded_world_size is None: - loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] - else: - assert ( - loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] - ), f"World size state mismatch among checkpoint files. File: {path}" - if loaded_zero_stage is None: - loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] - else: - assert ( - loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] - ), f"Zero stage mismatch among checkpoint files. File: {path}" - if loaded_optimizer_name is None: - loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] - else: - assert ( - loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] - ), f"Optimizer name mismatch among checkpoint files. File: {path}" - - # aggregate all model states - _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision, mode) - - if not pytorch_format: - # aggregate all optimizer states if pytorch_format is False - _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode) - - # for D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups - # to aggregate over Zero, and another pass to aggregate Megatron partitioned - # states. Preserve the relevant partition info only for weights that are megatron partitioned for - # a partial aggregation call - if partial_aggregation: - _aggregate_megatron_partition_info(rank_state_dict, state_dict) - - # entry for trainer_options in the state_dict to perform other sanity checks - if _utils.state_dict_trainer_options_key() not in state_dict: - _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation) - - # entry for user_dict in the state_dict if not already present - if ( - _utils.state_dict_user_dict_key() not in state_dict - and _utils.state_dict_user_dict_key() in rank_state_dict - ): - state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()] - - # for a partial aggregation scenario, we might not have the entire tensor aggregated yet, thus skip reshape - if not partial_aggregation: - # reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims - _reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision) - - # return a flat structure for PyTorch model in case pytorch_format is True - # else return the hierarchical structure for ORTTrainer - return _to_pytorch_format(state_dict) if pytorch_format else state_dict - - -def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format): # noqa: N802 - """Aggregate checkpoint files and return a single state dictionary for the D+H - (Zero+Megatron) partitioning strategy. - For D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups - to aggregate over Zero, and another pass over the previously aggregated states - to aggregate Megatron partitioned states. - """ - sharded_states_original_dims = {} - aggregate_data_checkpoint_files = [] - - # combine for Zero over data groups and save to temp file - with tempfile.TemporaryDirectory() as save_dir: - for group_id, d_group in enumerate(D_groups): - aggregate_state_dict = _aggregate_over_ranks( - ordered_paths["D"][group_id], - d_group, - sharded_states_original_dims, - partial_aggregation=True, - pytorch_format=False, - ) - - filename = "ort.data_group." + str(group_id) + ".ort.pt" - filepath = os.path.join(save_dir, filename) - _checkpoint_storage.save(aggregate_state_dict, filepath) - aggregate_data_checkpoint_files.append(filepath) - - assert len(aggregate_data_checkpoint_files) > 0 - - # combine for megatron: - aggregate_state = _aggregate_over_ranks( - aggregate_data_checkpoint_files, - H_groups[0], - sharded_states_original_dims, - mode=_AGGREGATION_MODE.Megatron, - pytorch_format=pytorch_format, - ) - - return aggregate_state - - -def aggregate_checkpoints(paths, pytorch_format=True): - """Aggregate checkpoint files and return a single state dictionary - - Aggregates checkpoint files specified by paths and loads them one at a time, merging - them into a single state dictionary. - The checkpoint files represented by paths must be saved through ORTTrainer.save_checkpoint() function. - The schema of the state_dict returned will be in the same as the one returned by ORTTrainer.state_dict() - - Args: - paths: list of more than one file represented as strings where the checkpoint is saved - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict - Returns: - state_dict that can be loaded into an ORTTrainer or into a PyTorch model - """ - - loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key()) - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - world_size = _utils.state_dict_trainer_options_world_size_key() - - D_size = loaded_trainer_options[D_size] # noqa: N806 - H_size = loaded_trainer_options[H_size] # noqa: N806 - world_size = loaded_trainer_options[world_size] - D_groups, H_groups = _get_parallellism_groups(D_size, H_size, world_size) # noqa: N806 - - combine_zero = loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 - combine_megatron = len(H_groups[0]) > 1 - - # order the paths in the order of groups in which they must be aggregated according to - # data-parallel groups and H-parallel groups obtained - # eg: {'D': [[path_0, path_2],[path_1, path_3]], 'H': [[path_0, path_1],[path_2, path_3]]} - ordered_paths = _order_paths(paths, D_groups, H_groups) - - aggregate_state = None - if combine_zero and combine_megatron: - aggregate_state = _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format) - elif combine_zero: - aggregate_state = _aggregate_over_ranks( - ordered_paths["D"][0], D_groups[0], mode=_AGGREGATION_MODE.Zero, pytorch_format=pytorch_format - ) - elif combine_megatron: - aggregate_state = _aggregate_over_ranks( - ordered_paths["H"][0], H_groups[0], mode=_AGGREGATION_MODE.Megatron, pytorch_format=pytorch_format - ) - - return aggregate_state - - -################################################################################ -# Helper functions -################################################################################ - - -def _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict): - checkpoint_name = _get_checkpoint_name( - checkpoint_prefix, - is_partitioned, - ort_trainer.options.distributed.world_rank, - ort_trainer.options.distributed.world_size, - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if is_partitioned: - assert_msg = ( - f"Couldn't find checkpoint file {checkpoint_file}." - " Optimizer partitioning is enabled using ZeRO. Please make sure the checkpoint file exists " - f"for rank {ort_trainer.options.distributed.world_rank} of {ort_trainer.options.distributed.world_size}" - ) - else: - assert_msg = f"Couldn't find checkpoint file {checkpoint_file}." - assert os.path.exists(checkpoint_file), assert_msg - - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - experimental_load_state_dict(ort_trainer, checkpoint_state["model"], strict=strict) - del checkpoint_state["model"] - return checkpoint_state - - -def _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict): - checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - - ckpt_agg = _CombineZeroCheckpoint(checkpoint_files) - aggregate_state_dict = ckpt_agg.aggregate_checkpoints() - - experimental_load_state_dict(ort_trainer, aggregate_state_dict, strict=strict) - - # aggregate other keys in the state_dict. - # Values will be overwritten for matching keys among workers - all_checkpoint_states = dict() - for checkpoint_file in checkpoint_files: - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - del checkpoint_state["model"] - all_checkpoint_states.update(checkpoint_state) - return all_checkpoint_states - - -def _list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"): - ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)] - ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)] - ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names] - - assert len(ckpt_file_names) > 0, f"No checkpoint found with prefix '{checkpoint_prefix}' at '{checkpoint_dir}'" - return ckpt_file_names - - -def _get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None): - SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806 - MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806 - - if is_partitioned: - filename = MULTIPLE_CHECKPOINT_FILENAME.format( - prefix=prefix, world_rank=world_rank, world_size=(world_size - 1) - ) - else: - filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix) - return filename - - -def _split_state_dict(state_dict): - optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] - split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} - for k, v in state_dict.items(): - mode = "fp32_param" - for optim_key in optimizer_keys: - if k.startswith(optim_key): - mode = "optimizer" - break - if k.endswith("_fp16"): - mode = "fp16_param" - split_sd[mode][k] = v - return split_sd - - -class _CombineZeroCheckpoint: - def __init__(self, checkpoint_files, clean_state_dict=None): - assert len(checkpoint_files) > 0, "No checkpoint files passed" - self.checkpoint_files = checkpoint_files - self.clean_state_dict = clean_state_dict - self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1 - assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files" - self.weight_shape_map = {} - self.sharded_params = set() - - def _split_name(self, name: str): - name_split = name.split("_view_") - view_num = None - if len(name_split) > 1: - view_num = int(name_split[1]) - optimizer_key = "" - mp_suffix = "" - if name_split[0].startswith("Moment_1"): - optimizer_key = "Moment_1_" - elif name_split[0].startswith("Moment_2"): - optimizer_key = "Moment_2_" - elif name_split[0].startswith("Update_Count"): - optimizer_key = "Update_Count_" - elif name_split[0].endswith("_fp16"): - mp_suffix = "_fp16" - param_name = name_split[0] - if optimizer_key: - param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split("_fp16")[0] - return param_name, optimizer_key, view_num, mp_suffix - - def _update_weight_statistics(self, name, value): - if name not in self.weight_shape_map: - self.weight_shape_map[name] = value.size() # original shape of tensor - - def _reshape_tensor(self, key): - value = self.aggregate_state_dict[key] - weight_name, _, _, _ = self._split_name(key) - set_size = self.weight_shape_map[weight_name] - self.aggregate_state_dict[key] = value.reshape(set_size) - - def _aggregate(self, param_dict): - for k, v in param_dict.items(): - weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k) - if view_num is not None: - # parameter is sharded - param_name = optimizer_key + weight_name + mp_suffix - - if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]: - self.sharded_params.add(param_name) - # Found a previous shard of the param, concatenate shards ordered by ranks - self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v)) - else: - self.aggregate_state_dict[param_name] = v - else: - if k in self.aggregate_state_dict: - assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value" - else: - self.aggregate_state_dict[k] = v - self._update_weight_statistics(weight_name, v) - - def aggregate_checkpoints(self): - warnings.warn( - "_CombineZeroCheckpoint.aggregate_checkpoints() will be deprecated soon. " - "Please use aggregate_checkpoints() instead.", - DeprecationWarning, - ) - - checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0] - self.aggregate_state_dict = dict() - - for i in range(self.world_size): - checkpoint_name = _get_checkpoint_name(checkpoint_prefix, True, i, self.world_size) - rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu")) - if "model" in rank_state_dict: - rank_state_dict = rank_state_dict["model"] - - if self.clean_state_dict: - rank_state_dict = self.clean_state_dict(rank_state_dict) - - rank_state_dict = _split_state_dict(rank_state_dict) - self._aggregate(rank_state_dict["fp16_param"]) - self._aggregate(rank_state_dict["fp32_param"]) - self._aggregate(rank_state_dict["optimizer"]) - - for k in self.sharded_params: - self._reshape_tensor(k) - return self.aggregate_state_dict diff --git a/orttraining/orttraining/python/training/model_desc_validation.py b/orttraining/orttraining/python/training/model_desc_validation.py deleted file mode 100644 index dd3f4cb95cd59..0000000000000 --- a/orttraining/orttraining/python/training/model_desc_validation.py +++ /dev/null @@ -1,408 +0,0 @@ -from collections import namedtuple - -import cerberus -import torch - -from ._utils import static_vars - -LEARNING_RATE_IO_DESCRIPTION_NAME = "__learning_rate" -ALL_FINITE_IO_DESCRIPTION_NAME = "__all_finite" -LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME = "__loss_scale_input_name" -GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME = "__gradient_accumulation_name" - - -class _ORTTrainerModelDesc: - def __init__(self, model_desc): - # Keep a copy of original input for debug - self._original = dict(model_desc) - - # Global counter used to validate occurrences of 'is_loss=True' whithin 'model_desc.outputs' - # A stateless validator is used for each tuple, but validation accross the whole list of tuple is needed - # because just one 'is_loss=True' is allowed withing 'model_desc.outputs' list of tuples - _model_desc_outputs_validation.loss_counter = 0 - - # Used for logging purposes - self._main_class_name = self.__class__.__name__ - - # Validates user input - self._validated = dict(self._original) - validator = cerberus.Validator(MODEL_DESC_SCHEMA) - self._validated = validator.validated(self._validated) - if self._validated is None: - raise ValueError(f"Invalid model_desc: {validator.errors}") - - # Normalize inputs to a list of namedtuple(name, shape) - self._InputDescription = namedtuple("InputDescription", ["name", "shape"]) - self._InputDescriptionTyped = namedtuple("InputDescriptionTyped", ["name", "shape", "dtype"]) - for idx, input in enumerate(self._validated["inputs"]): - self._validated["inputs"][idx] = self._InputDescription(*input) - - # Normalize outputs to a list of namedtuple(name, shape, is_loss) - self._OutputDescription = namedtuple("OutputDescription", ["name", "shape", "is_loss"]) - self._OutputDescriptionTyped = namedtuple( - "OutputDescriptionTyped", ["name", "shape", "is_loss", "dtype", "dtype_amp"] - ) - for idx, output in enumerate(self._validated["outputs"]): - if len(output) == 2: - self._validated["outputs"][idx] = self._OutputDescription(*output, False) - else: - self._validated["outputs"][idx] = self._OutputDescription(*output) - - # Hard-code learning rate, all_finite descriptors - self.learning_rate = self._InputDescriptionTyped(LEARNING_RATE_IO_DESCRIPTION_NAME, [1], torch.float32) - - # Convert dict in object - for k, v in self._validated.items(): - setattr(self, k, self._wrap(v)) - - def __repr__(self): - """Pretty representation for a model description class""" - - pretty_msg = "Model description:\n" - - # Inputs - inputs = [] - for i_desc in self.inputs: - if isinstance(i_desc, self._InputDescription): - inputs.append(f"(name={i_desc.name}, shape={i_desc.shape})") - elif isinstance(i_desc, self._InputDescriptionTyped): - inputs.append(f"(name={i_desc.name}, shape={i_desc.shape}, dtype={i_desc.dtype})") - else: - raise ValueError(f"Unexpected type {type(i_desc)} for input description") - - pretty_msg += "\nInputs:" - for idx, item in enumerate(inputs): - pretty_msg += f"\n\t{idx}: {item}" - - # Outputs - outputs = [] - for o_desc in self.outputs: - if isinstance(o_desc, self._OutputDescription): - outputs.append(f"(name={o_desc.name}, shape={o_desc.shape})") - elif isinstance(o_desc, self._OutputDescriptionTyped): - outputs.append( - f"(name={o_desc.name}, shape={o_desc.shape}, dtype={o_desc.dtype}, dtype_amp={o_desc.dtype_amp})" - ) - else: - raise ValueError(f"Unexpected type {type(o_desc)} for output description") - pretty_msg += "\nOutputs:" - for idx, item in enumerate(outputs): - pretty_msg += f"\n\t{idx}: {item}" - - # Learning rate - if self.learning_rate: - pretty_msg += "\nLearning rate: " - pretty_msg += ( - f"(name={self.learning_rate.name}, shape={self.learning_rate.shape}, dtype={self.learning_rate.dtype})" - ) - - # Mixed precision - if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) or getattr( - self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None - ): - pretty_msg += "\nMixed Precision:" - if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None): - pretty_msg += "\n\tis gradients finite: " - pretty_msg += ( - f"(name={self.all_finite.name}, shape={self.all_finite.shape}, dtype={self.all_finite.dtype})" - ) - if getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None): - pretty_msg += "\n\tloss scale input name: " - pretty_msg += f"(name={self.loss_scale_input.name}, shape={self.loss_scale_input.shape}, dtype={self.loss_scale_input.dtype})" - - # Gradient Accumulation steps - if self.gradient_accumulation: - pretty_msg += "\nGradient Accumulation: " - pretty_msg += f"(name={self.gradient_accumulation.name}, shape={self.gradient_accumulation.shape}, dtype={self.gradient_accumulation.dtype})" - - return pretty_msg - - def add_type_to_input_description(self, index, dtype): - """Updates an existing input description at position 'index' with 'dtype' type information - - Args: - index (int): position within 'inputs' description - dtype (torch.dtype): input data type - """ - - assert isinstance(index, int) and index >= 0, "input 'index' must be a positive int" - assert isinstance(dtype, torch.dtype), "input 'dtype' must be a torch.dtype type" - existing_values = (*self.inputs[index],) - if isinstance(self.inputs[index], self._InputDescriptionTyped): - existing_values = (*existing_values[:-1],) - self.inputs[index] = self._InputDescriptionTyped(*existing_values, dtype) - - def add_type_to_output_description(self, index, dtype, dtype_amp=None): - """Updates an existing output description at position 'index' with 'dtype' type information - - Args: - index (int): position within 'inputs' description - dtype (torch.dtype): input data type - dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision - """ - - assert isinstance(index, int) and index >= 0, "output 'index' must be a positive int" - assert isinstance(dtype, torch.dtype), "output 'dtype' must be a torch.dtype type" - assert dtype_amp is None or isinstance( - dtype_amp, torch.dtype - ), "output 'dtype_amp' must be either None or torch.dtype type" - existing_values = (*self.outputs[index],) - if isinstance(self.outputs[index], self._OutputDescriptionTyped): - existing_values = (*existing_values[:-2],) - self.outputs[index] = self._OutputDescriptionTyped(*existing_values, dtype, dtype_amp) - - @property - def gradient_accumulation(self): - return getattr(self, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, None) - - @gradient_accumulation.setter - def gradient_accumulation(self, name): - self._add_output_description( - self, name, [1], False, torch.bool, None, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - @property - def all_finite(self): - return getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) - - @all_finite.setter - def all_finite(self, name): - self._add_output_description( - self, name, [1], False, torch.bool, None, ALL_FINITE_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - @property - def loss_scale_input(self): - return getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None) - - @loss_scale_input.setter - def loss_scale_input(self, name): - self._add_input_description( - self, name, [], torch.float32, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - def _add_input_description(self, node, name, shape, dtype=None, attr_name=None, ignore_duplicate=False): - """Add a new input description into the node object - - If 'dtype' is specified, a typed input description namedtuple(name, shape, dtype) is created. - Otherwise an untyped input description namedtuple(name, shape) is created instead. - - Args: - node (list or object): node to append input description to. When 'node' is 'self.inputs', - a new input description is appended to the list. - Otherwise, a new input description is created as an attribute into 'node' with name 'attr_name' - name (str): name of input description - shape (list): shape of input description - dtype (torch.dtype): input data type - attr_name (str, default is None): friendly name to allow direct access to the output description - ignore_duplicate (bool, default is False): silently skips addition of duplicate inputs - """ - - assert isinstance(name, str) and len(name) > 0, "'name' is an invalid input name" - not_found = True - if not ignore_duplicate: - if id(node) == id(self.inputs): - not_found = all([name not in i_desc.name for i_desc in node]) - assert not_found, f"'name' {name} already exists in the inputs description" - else: - not_found = attr_name not in dir(self) - assert not_found, f"'attr_name' {attr_name} already exists in the 'node'" - elif not not_found: - return - assert isinstance(shape, list) and all( - [(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape] - ), "'shape' must be a list of int or str with length at least 1" - assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type" - if dtype: - new_input_desc = self._InputDescriptionTyped(name, shape, dtype) - else: - new_input_desc = self._InputDescription(name, shape) - - if id(node) == id(self.inputs): - self.inputs.append(new_input_desc) - else: - assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'" - setattr(node, attr_name, new_input_desc) - - def _add_output_description( - self, node, name, shape, is_loss, dtype=None, dtype_amp=None, attr_name=None, ignore_duplicate=False - ): - """Add a new output description into the node object as a tuple - - When (name, shape, is_loss, dtype) is specified, a typed output description is created - Otherwise an untyped output description (name, shape, is_loss) is created instead - - Args: - node (list or object): node to append output description to. When 'node' is 'self.outputs', - a new output description is appended to the list. - Otherwise, a new output description is created as an attribute into 'node' with name 'attr_name' - name (str): name of output description - shape (list): shape of output description - is_loss (bool): specifies whether this output is a loss - dtype (torch.dtype): input data type - dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision. - attr_name (str, default is None): friendly name to allow direct access to the output description - ignore_duplicate (bool, default is False): silently skips addition of duplicate outputs - """ - - assert isinstance(name, str) and len(name) > 0, "'name' is an invalid output name" - assert isinstance(shape, list) and all( - [(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape] - ), "'shape' must be a list of int or str with length at least 1" - assert isinstance(is_loss, bool), "'is_loss' must be a bool" - - not_found = True - if not ignore_duplicate: - if id(node) == id(self.outputs): - not_found = all([name not in o_desc.name for o_desc in node]) - assert not_found, f"'name' {name} already exists in the outputs description" - assert ( - all([not o_desc.is_loss for o_desc in node]) if is_loss else True - ), "Only one 'is_loss' is supported at outputs description" - else: - not_found = attr_name not in dir(self) - assert not_found, f"'attr_name' {attr_name} already exists in the 'node'" - elif not not_found: - return - - assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type" - if dtype: - new_output_desc = self._OutputDescriptionTyped(name, shape, is_loss, dtype, None) - else: - new_output_desc = self._OutputDescription(name, shape, is_loss) - - if id(node) == id(self.outputs): - self.outputs.append(new_output_desc) - else: - assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'" - setattr(node, attr_name, new_output_desc) - - def _wrap(self, v): - """Add 'v' as self's attribute to allow direct access as self.v""" - if isinstance(v, (list)): - return type(v)([self._wrap(v) for v in v]) - elif isinstance( - v, - ( - self._InputDescription, - self._InputDescriptionTyped, - self._OutputDescription, - self._OutputDescriptionTyped, - ), - ): - return v - elif isinstance(v, (tuple)): - return type(v)([self._wrap(v) for v in v]) - elif isinstance(v, (dict, int, float, bool, str)): - return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v - else: - raise ValueError( - f"Unsupported type for model_desc ({v})." - "Only int, float, bool, str, list, tuple and dict are supported" - ) - - -class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc): - r"""Internal class used by ONNX Runtime training backend for input validation - - NOTE: Users MUST NOT use this class in any way! - """ - - def __init__(self, main_class_name, model_desc): - # Used for logging purposes - self._main_class_name = main_class_name - - # Convert dict in object - for k, v in dict(model_desc).items(): - setattr(self, k, self._wrap(v)) - - -def _model_desc_inputs_validation(field, value, error): - r"""Cerberus custom check method for 'model_desc.inputs' - - 'model_desc.inputs' is a list of tuples. - The list has variable length, but each tuple has size 2 - - The first element of the tuple is a string which represents the input name - The second element is a list of shapes. Each shape must be either an int or string. - Empty list represents a scalar output - - Validation is done within each tuple to enforce the schema described above. - - Example: - - .. code-block:: python - - model_desc['inputs'] = [('input1', ['batch', 1024]), - ('input2', []) - ('input3', [512])] - """ - - if not isinstance(value, tuple) or len(value) != 2: - error(field, "must be a tuple with size 2") - if not isinstance(value[0], str): - error(field, "the first element of the tuple (aka name) must be a string") - if not isinstance(value[1], list): - error(field, "the second element of the tuple (aka shape) must be a list") - else: - for shape in value[1]: - if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool): - error(field, "each shape must be either a string or integer") - - -@static_vars(loss_counter=0) -def _model_desc_outputs_validation(field, value, error): - r"""Cerberus custom check method for 'model_desc.outputs' - - 'model_desc.outputs' is a list of tuples with variable length. - The first element of the tuple is a string which represents the output name - The second element is a list of shapes. Each shape must be either an int or string. - Empty list represents a scalar output - The third element is optional and is a flag that signals whether the output is a loss value - - Validation is done within each tuple to enforce the schema described above, but also - throughout the list of tuples to ensure a single 'is_loss=True' occurrence. - - Example: - - .. code-block:: python - - model_desc['outputs'] = [('output1', ['batch', 1024], is_loss=True), - ('output2', [], is_loss=False) - ('output3', [512])] - """ - - if not isinstance(value, tuple) or len(value) < 2 or len(value) > 3: - error(field, "must be a tuple with size 2 or 3") - if len(value) == 3 and not isinstance(value[2], bool): - error(field, "the third element of the tuple (aka is_loss) must be a boolean") - elif len(value) == 3: - if value[2]: - _model_desc_outputs_validation.loss_counter += 1 - if _model_desc_outputs_validation.loss_counter > 1: - error(field, "only one is_loss can bet set to True") - if not isinstance(value[0], str): - error(field, "the first element of the tuple (aka name) must be a string") - if not isinstance(value[1], list): - error(field, "the second element of the tuple (aka shape) must be a list") - else: - for shape in value[1]: - if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool): - error(field, "each shape must be either a string or integer") - - -# Validation schema for model description dictionary -MODEL_DESC_SCHEMA = { - "inputs": { - "type": "list", - "required": True, - "minlength": 1, - "schema": {"check_with": _model_desc_inputs_validation}, - }, - "outputs": { - "type": "list", - "required": True, - "minlength": 1, - "schema": {"check_with": _model_desc_outputs_validation}, - }, -} diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index 0bf402b750115..462491365c1fa 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -45,17 +45,15 @@ class TritonCodegen(NodeVisitor): Specialized codegen for Triton backend. """ - def __init__(self): - super().__init__() - def codegen(self, node: IRNode, context: CodegenContext, code_buffer: CodeBuffer, indent: int): func = getattr(self, node.__class__.__name__) - assert func is not None, "unimplemented node: %s" % node.__class__.__name__ + assert func is not None, f"unimplemented node: {node.__class__.__name__}" func(node, context, code_buffer, indent) def _get_elementwise_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]: if offset_calc.is_x_reduced(arg_name): - return "", "" + # Scalar. + return "tl.full([1], 0, tl.int32)", "" if offset_calc.is_same_x_shape(arg_name): return "xindex", "xmask" if offset_calc.requires_x_mask else "" strides = offset_calc.get_input_strides(arg_name) @@ -91,13 +89,16 @@ def _get_reduce_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) if offset_calc.requires_r_mask: mask_strs.append("rmask") + # If both is_x_reduced and is_r_reduced are True, it's scalar. + if len(offset_strs) == 0: + offset_strs.append("tl.full([1, 1], 0, tl.int32)") return " + ".join(offset_strs), " & ".join(mask_strs) - def _get_offset_mask(self, node: OffsetCalculator, arg_name: str) -> Tuple[str, str]: + def _get_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]: return ( - self._get_reduce_offset_mask(node, arg_name) - if node.is_reduction - else self._get_elementwise_offset_mask(node, arg_name) + self._get_reduce_offset_mask(offset_calc, arg_name) + if offset_calc.is_reduction + else self._get_elementwise_offset_mask(offset_calc, arg_name) ) def IONode(self, node: IONode, context: CodegenContext, code_buffer: CodeBuffer, indent: int): # noqa: N802 @@ -125,18 +126,29 @@ def IONode(self, node: IONode, context: CodegenContext, code_buffer: CodeBuffer, def _gen_kernel_signature(self, node: KernelNode, context: CodegenContext, code_buffer: CodeBuffer, indent: int): is_reduction = node.offset_calc.is_reduction space_indent = " " * indent - autotune_configs_str = "" - for config in node.offset_calc.autotune_configs.configs: - if is_reduction: - autotune_configs_str += ( - f'{space_indent} triton.Config({{"XBLOCK": {config[0]}, "RBLOCK": {config[1]}}}, ' - f"num_warps={config[2]}),\n" - ) - else: - autotune_configs_str += ( - f'{space_indent} triton.Config({{"XBLOCK": {config[0]}}}, num_warps={config[2]}),\n' - ) - keys_str = '"xnumel", "rnumel"' if is_reduction else '"xnumel"' + + if len(node.offset_calc.autotune_configs.configs) > 1: + autotune_configs_str = "" + for config in node.offset_calc.autotune_configs.configs: + if is_reduction: + autotune_configs_str += ( + f'{space_indent} triton.Config({{"XBLOCK": {config[0]}, "RBLOCK": {config[1]}}}, ' + f"num_warps={config[2]}),\n" + ) + else: + autotune_configs_str += ( + f'{space_indent} triton.Config({{"XBLOCK": {config[0]}}}, num_warps={config[2]}),\n' + ) + keys_str = '"xnumel", "rnumel"' if is_reduction else '"xnumel"' + code_buffer += ( + f"{space_indent}@triton.autotune(\n" + f"{space_indent} configs=[\n" + f"{autotune_configs_str}" + f"{space_indent} ],\n" + f"{space_indent} key=[{keys_str}],\n" + f"{space_indent})\n" + ) + input_args = [context.get_variable_name(input.name) for input in node.inputs] input_args_str = ", ".join(input_args) if input_args_str: @@ -158,12 +170,6 @@ def _gen_kernel_signature(self, node: KernelNode, context: CodegenContext, code_ ) code_buffer += ( - f"{space_indent}@triton.autotune(\n" - f"{space_indent} configs=[\n" - f"{autotune_configs_str}" - f"{space_indent} ],\n" - f"{space_indent} key=[{keys_str}],\n" - f"{space_indent})\n" f"{space_indent}@triton.jit\n" f"{space_indent}def {node.name}({input_args_str}{output_args_str}{other_input_args}{blocks_str}):\n" ) @@ -175,8 +181,10 @@ def ElementwiseKernelNode( # noqa: N802 offset_calc = node.offset_calc indent += 4 space_indent = " " * indent + x_numel_str = str(offset_calc.x_numel) + if x_numel_str.isnumeric(): + code_buffer += f"{space_indent}xnumel = {x_numel_str}\n" code_buffer += ( - f"{space_indent}xnumel = {offset_calc.x_numel}\n" f"{space_indent}xoffset = tl.program_id(0) * XBLOCK\n" f"{space_indent}xindex = xoffset + tl.arange(0, XBLOCK)\n" ) @@ -207,9 +215,13 @@ def ReduceKernelNode( # noqa: N802 offset_calc = node.offset_calc indent += 4 space_indent = " " * indent + x_numel_str = str(offset_calc.x_numel) + if x_numel_str.isnumeric(): + code_buffer += f"{space_indent}xnumel = {x_numel_str}\n" + r_numel_str = str(offset_calc.r_numel) + if r_numel_str.isnumeric(): + code_buffer += f"{space_indent}rnumel = {r_numel_str}\n" code_buffer += ( - f"{space_indent}xnumel = {offset_calc.x_numel}\n" - f"{space_indent}rnumel = {offset_calc.r_numel}\n" f"{space_indent}xoffset = tl.program_id(0) * XBLOCK\n" f"{space_indent}xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n" f"{space_indent}rbase = tl.arange(0, RBLOCK)[None, :]\n" @@ -444,6 +456,13 @@ def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: Cod indent += 4 space_indent = " " * indent + seen_symbolic_shape = set() + for input in node.inputs: + for idx, dim in enumerate(input.shape): + if dim.is_symbol and dim not in seen_symbolic_shape: + code_buffer += f"{space_indent}{dim} = {context.get_variable_name(input.name)}.size()[{idx}]\n" + seen_symbolic_shape.add(dim) + if node.has_dropout: code_buffer += ( f'{space_indent}seed_cuda = torch.randint(2**31, size=(), dtype=torch.int64, device="cuda")\n\n' @@ -470,18 +489,31 @@ def ModuleNode(self, node: ModuleNode, context: CodegenContext, code_buffer: Cod if kernel_node.has_dropout: kernel_args_str += ", seed_cuda" + # Support symbolic shape if any. + symbolic_shape_args_str = ", ".join(kernel_node.symbolic_shape_variables) + if symbolic_shape_args_str: + kernel_args_str += f", {symbolic_shape_args_str}" + + block_str = "" + if len(kernel_node.offset_calc.autotune_configs.configs) == 1: + config = kernel_node.offset_calc.autotune_configs.configs[0] + if kernel_node.offset_calc.is_reduction: + block_str = f", XBLOCK={config[0]}, RBLOCK={config[1]}, num_warps={config[2]}" + else: + block_str = f", XBLOCK={config[0]}, num_warps={config[2]}" + if isinstance(kernel_node, ReduceKernelNode): code_buffer += ( f"{space_indent}x_numel = {kernel_node.offset_calc.x_numel}\n" f"{space_indent}r_numel = {kernel_node.offset_calc.r_numel}\n" f'{space_indent}grid = lambda meta: (triton.cdiv(x_numel, meta["XBLOCK"]),)\n' - f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, x_numel, r_numel)\n" + f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, x_numel, r_numel{block_str})\n" ) else: code_buffer += ( f"{space_indent}n_elements = {kernel_node.offset_calc.x_numel}\n" f'{space_indent}grid = lambda meta: (triton.cdiv(n_elements, meta["XBLOCK"]),)\n' - f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, n_elements)\n" + f"{space_indent}{kernel_node.name}[grid]({kernel_args_str}, n_elements{block_str})\n" ) for name in node.cross_kernel_args_to_delete[idx]: diff --git a/orttraining/orttraining/python/training/ort_triton/_common.py b/orttraining/orttraining/python/training/ort_triton/_common.py index 82ac82cfa2919..b7e55bc733ede 100644 --- a/orttraining/orttraining/python/training/ort_triton/_common.py +++ b/orttraining/orttraining/python/training/ort_triton/_common.py @@ -9,9 +9,11 @@ import sympy from onnx import GraphProto, NodeProto, TensorProto -from ._sympy_utils import parse_shape +from ._sympy_utils import extract_shape_from_symbol from ._utils import get_attribute, get_reduce_info, next_power_of_2 +_SPECIAL_FLOATS: List[str] = ["inf", "-inf"] + class CodegenContext: """ @@ -28,7 +30,8 @@ def get_variable_name(self, name: str) -> str: # For some operators such as data load/store, we need an internal variable name inside the kernel function. def get_internal_variable_name(self, name: str) -> str: var_name = self._var_map[name] - return self._var_map[var_name] if var_name in self._var_map else var_name + var_name = self._var_map[var_name] if var_name in self._var_map else var_name + return f'float("{var_name}")' if var_name in _SPECIAL_FLOATS else var_name class CodeBuffer: @@ -49,14 +52,38 @@ def codegen(self, node: Any, context: CodegenContext, code_buffer: CodeBuffer, i pass +class SymbolicDSU: + """ + A 'disjoint set union' to merge symbolics so that we use less variables in the generated code. + When handling shape inference for elementwise Ops, if two symbols are not equal and they are not 1, we merge them. + """ + + def __init__(self): + self._dsu: Dict[sympy.Expr, sympy.Expr] = {} + + def find(self, symbolic: sympy.Expr) -> sympy.Expr: + if symbolic not in self._dsu: + self._dsu[symbolic] = symbolic + return symbolic + if symbolic == self._dsu[symbolic]: + return symbolic + self._dsu[symbolic] = self.find(self._dsu[symbolic]) + return self._dsu[symbolic] + + def union(self, symbolic: sympy.Expr, other_symbolic: sympy.Expr): + root = self.find(symbolic) + other_root = self.find(other_symbolic) + self._dsu[other_root] = root + + class TensorInfo: """ Represent a input/output tensor of a node. """ - def __init__(self, dtype: TensorProto.DataType, shape: List[Any]): + def __init__(self, dtype: TensorProto.DataType, shape: List[sympy.Expr]): self._dtype: TensorProto.DataType = dtype - self._shape: List[sympy.Expr] = parse_shape(shape) + self._shape: List[sympy.Expr] = shape @property def dtype(self) -> TensorProto.DataType: @@ -66,27 +93,42 @@ def dtype(self) -> TensorProto.DataType: def shape(self) -> List[sympy.Expr]: return self._shape + def update_shape(self, symbolics: SymbolicDSU): + self._shape = [symbolics.find(dim) if dim.is_symbol else dim for dim in self._shape] + -def _infer_elementwise_shape(input_infos: List[TensorInfo]) -> List[sympy.Expr]: +def _infer_elementwise_shape(input_infos: List[TensorInfo], symbolics: SymbolicDSU) -> List[sympy.Expr]: max_len = max([len(input_info.shape) for input_info in input_infos]) output_shape: List[sympy.Expr] = [sympy.Integer(1)] * max_len for input_info in input_infos: offset = max_len - len(input_info.shape) - for i in range(len(input_info.shape)): - if not input_info.shape[i].is_number or input_info.shape[i] != 1: - output_shape[i + offset] = input_info.shape[i] + for idx, dim in enumerate(input_info.shape): + if not dim.is_number or dim != 1: + if not output_shape[idx + offset].is_number or output_shape[idx + offset] != 1: + symbolics.union(output_shape[idx + offset], dim) + else: + output_shape[idx + offset] = dim return output_shape -def _infer_elementwise(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]: - return [TensorInfo(input_infos[0].dtype, _infer_elementwise_shape(input_infos))] +def _infer_elementwise( + node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> List[TensorInfo]: + # pylint: disable=unused-argument + return [TensorInfo(input_infos[0].dtype, _infer_elementwise_shape(input_infos, symbolics))] -def _infer_where(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]: - return [TensorInfo(input_infos[1].dtype, _infer_elementwise_shape(input_infos))] +def _infer_where( + node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> List[TensorInfo]: + # pylint: disable=unused-argument + return [TensorInfo(input_infos[1].dtype, _infer_elementwise_shape(input_infos, symbolics))] -def _infer_reduction(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]: +def _infer_reduction( + node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> List[TensorInfo]: + # pylint: disable=unused-argument input_rank = len(input_infos[0].shape) keep_dims, axes = get_reduce_info(node, graph, input_rank) axes = [axis + input_rank if axis < 0 else axis for axis in axes] @@ -98,17 +140,26 @@ def _infer_reduction(node: NodeProto, input_infos: List[TensorInfo], graph: Grap return [TensorInfo(input_infos[0].dtype, shape)] -def _infer_unary(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]: +def _infer_unary( + node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> List[TensorInfo]: + # pylint: disable=unused-argument return [input_infos[0]] -def _infer_cast(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]: +def _infer_cast( + node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> List[TensorInfo]: + # pylint: disable=unused-argument dtype = get_attribute(node, "to", TensorProto.UNDEFINED) assert dtype != TensorProto.UNDEFINED return [TensorInfo(dtype, input_infos[0].shape)] -def _infer_dropout(node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]: +def _infer_dropout( + node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU +) -> List[TensorInfo]: + # pylint: disable=unused-argument return [input_infos[0], TensorInfo(TensorProto.BOOL, input_infos[0].shape)] @@ -138,10 +189,12 @@ class TypeAndShapeInfer: } @classmethod - def infer(cls, node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto) -> List[TensorInfo]: + def infer( + cls, node: NodeProto, input_infos: List[TensorInfo], graph: GraphProto, symbolics: SymbolicDSU + ) -> List[TensorInfo]: if node.op_type not in cls._INFER_FUNC_MAP: raise NotImplementedError(f"Unsupported op type: {node.op_type}") - return cls._INFER_FUNC_MAP[node.op_type](node, input_infos, graph) + return cls._INFER_FUNC_MAP[node.op_type](node, input_infos, graph, symbolics) class AutotuneConfigs: @@ -152,9 +205,30 @@ class AutotuneConfigs: If it's reduction kernel on last contiguous dimensions, the contiguous flag is True. """ - def __init__(self, x_numel: int, r_numel: int, contiguous: bool): - self.configs: List[Tuple[int, int, int]] = self._gen_autotune_configs(x_numel, r_numel, contiguous) - self.requires_for_loop: bool = any(config[1] < r_numel for config in self.configs) + def __init__(self, x_numel: sympy.Expr, r_numel: sympy.Expr, contiguous: bool): + x_numel_int = ( + int(x_numel) + if x_numel.is_number + else int( + x_numel.subs( + {symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in x_numel.free_symbols} + ) + ) + ) + r_numel_int = ( + int(r_numel) + if r_numel.is_number + else int( + r_numel.subs( + {symbol: sympy.Integer(extract_shape_from_symbol(symbol.name)) for symbol in r_numel.free_symbols} + ) + ) + ) + self.configs: List[Tuple[int, int, int]] = self._gen_autotune_configs(x_numel_int, r_numel_int, contiguous) + # If there is symbolic shape, we will not tune the kernel. + if not x_numel.is_number or not r_numel.is_number: + self.configs = self.configs[-1:] + self.requires_for_loop: bool = any(config[1] < r_numel_int for config in self.configs) def _num_warps(self, x: int, r: int) -> int: return min(max(x * r // 256, 2), 8) diff --git a/orttraining/orttraining/python/training/ort_triton/_decompose.py b/orttraining/orttraining/python/training/ort_triton/_decompose.py index e18bb16bb80db..ffd20b09b42ea 100644 --- a/orttraining/orttraining/python/training/ort_triton/_decompose.py +++ b/orttraining/orttraining/python/training/ort_triton/_decompose.py @@ -58,7 +58,7 @@ def _get_dtype_and_shape(self, arg_name: str, **kwargs): arg_info = node_arg_infos[arg_name] return arg_info.dtype, arg_info.shape - def _decompose_elementwise_precision(self, node: NodeProto, graph: GraphProto, **kwargs): + def _decompose_elementwise_precision(self, node: NodeProto, **kwargs): x = node.input[0] dtype, _ = self._get_dtype_and_shape(x, **kwargs) if not _is_half_dtype(dtype): @@ -79,15 +79,19 @@ def _decompose_elementwise_precision(self, node: NodeProto, graph: GraphProto, * return [*cast_nodes, op_node, cast_node1] def Exp(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802 - return self._decompose_elementwise_precision(node, graph, **kwargs) + # pylint: disable=unused-argument + return self._decompose_elementwise_precision(node, **kwargs) def Pow(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802 - return self._decompose_elementwise_precision(node, graph, **kwargs) + # pylint: disable=unused-argument + return self._decompose_elementwise_precision(node, **kwargs) def Sqrt(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802 - return self._decompose_elementwise_precision(node, graph, **kwargs) + # pylint: disable=unused-argument + return self._decompose_elementwise_precision(node, **kwargs) def LayerNormalization(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802 + # pylint: disable=unused-argument node_name = node.name x = node.input[0] w = node.input[1] @@ -153,6 +157,7 @@ def LayerNormalization(self, node: NodeProto, graph: GraphProto, **kwargs): # n ] def LayerNormalizationGrad(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802 + # pylint: disable=unused-argument node_name = node.name dy = node.input[0] x = node.input[1] @@ -241,6 +246,7 @@ def LayerNormalizationGrad(self, node: NodeProto, graph: GraphProto, **kwargs): return decomposed_nodes def Softmax(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802 + # pylint: disable=unused-argument node_name = node.name x = node.input[0] y = node.output[0] @@ -259,6 +265,7 @@ def Softmax(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802 return [max_node, sub_node, exp_node, sum_node, div_node] def SoftmaxGrad_13(self, node: NodeProto, graph: GraphProto, **kwargs): # noqa: N802 + # pylint: disable=unused-argument node_name = node.name dy = node.input[0] y = node.input[1] diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index f7d3b31eac5b6..50121cbf49804 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -88,13 +88,15 @@ def __init__(self, target_shape: List[sympy.Expr], reduce_axes: List[int]): self.r_strides.insert(0, self.r_strides[0] * self.r_dims[i + 1]) self.r_compute_dims: Set[int] = set() self.input_strides: Dict[str, List[sympy.Expr]] = dict() - # Support concrete shape only for now. - assert self.x_numel.is_integer and self.r_numel.is_integer self.autotune_configs: AutotuneConfigs = AutotuneConfigs( - int(self.x_numel), int(self.r_numel), not self.is_reduction or self.reduce_axes[-1] == self.rank - 1 + self.x_numel, self.r_numel, not self.is_reduction or self.reduce_axes[-1] == self.rank - 1 + ) + self.requires_x_mask: bool = not self.x_numel.is_number or any( + int(self.x_numel) % config[0] != 0 for config in self.autotune_configs.configs + ) + self.requires_r_mask: bool = not self.r_numel.is_number or any( + int(self.r_numel) % config[1] != 0 for config in self.autotune_configs.configs ) - self.requires_x_mask: bool = any(int(self.x_numel) % config[0] != 0 for config in self.autotune_configs.configs) - self.requires_r_mask: bool = any(int(self.r_numel) % config[1] != 0 for config in self.autotune_configs.configs) self.reduced_args: Set[str] = set() def get_input_strides(self, name: str) -> List[sympy.Expr]: @@ -129,7 +131,7 @@ def register_tensor_arg(self, tensor_arg: TensorArg): input_shape = tensor_arg.shape if tensor_arg.name in self.reduced_args: assert self.is_reduction - reduced_rank = len(input_shape) - len(self.reduce_axes) + reduced_rank = len(self.target_shape) - len(self.reduce_axes) if len(input_shape) < reduced_rank: input_shape = [sympy.Integer(1)] * (reduced_rank - len(input_shape)) + input_shape input_shape = ( @@ -141,7 +143,9 @@ def register_tensor_arg(self, tensor_arg: TensorArg): input_shape = [sympy.Integer(1)] * (len(self.target_shape) - len(input_shape)) + input_shape running_stride = sympy.Integer(1) for i in range(len(self.target_shape) - 1, -1, -1): - if self.target_shape[i] == input_shape[i]: + if self.target_shape[i] == input_shape[i] and not ( + tensor_arg.name in self.reduced_args and i in self.reduce_axes + ): strides.insert(0, running_stride) running_stride = running_stride * input_shape[i] else: @@ -300,17 +304,20 @@ def gen_variable_names(self): self.var_map[name] = "t_" + name for name in self.internal_args: self.var_map[name] = gen_variable_name(name, "t", existing_names) - for constant_name in self.constants: - self.var_map[constant_name] = gen_variable_name(constant_name, "c", existing_names) - if self.constants[constant_name].data is not None: - value = self.constants[constant_name].data + for name, tensor_arg in self.constants.items(): + self.var_map[name] = gen_variable_name(name, "c", existing_names) + if tensor_arg.data is not None: + value = tensor_arg.data if value is not None: assert value.size == 1, f"unsupported constant array {value}" - variable_name = self.var_map[constant_name] + variable_name = self.var_map[name] assert variable_name not in self.var_map self.var_map[variable_name] = str(np.array(value.item(), value.dtype)) - - self.symbolic_shape_variables = [str(dim) for dim in self.target_shape if dim.is_symbol] + seen = set() + for dim in self.target_shape: + if dim.is_symbol and dim not in seen: + seen.add(dim) + self.symbolic_shape_variables.append(str(dim)) class ElementwiseKernelNode(KernelNode): diff --git a/orttraining/orttraining/python/training/ort_triton/_lowering.py b/orttraining/orttraining/python/training/ort_triton/_lowering.py index 5de60e69437a0..5c848d2cecc58 100644 --- a/orttraining/orttraining/python/training/ort_triton/_lowering.py +++ b/orttraining/orttraining/python/training/ort_triton/_lowering.py @@ -51,10 +51,8 @@ def __init__(self, node: NodeProto, reduce_axes: List[int], keep_dims: int, node # r_numel is meant to hint how many elements in a row of tensor will be processed by each kernel. # r is a abbreviation of reduction, so, it's only used for reduction nodes. r_numel: sympy.Expr = sympy.prod(r_dims) if len(r_dims) > 0 else sympy.Integer(1) - # Support concrete shape only for now. - assert x_numel.is_integer and r_numel.is_integer self.autotune_configs: AutotuneConfigs = AutotuneConfigs( - int(x_numel), int(r_numel), len(self.reduce_axes) == 0 or self.reduce_axes[-1] == rank - 1 + x_numel, r_numel, len(self.reduce_axes) == 0 or self.reduce_axes[-1] == rank - 1 ) self.reduced_args: Set[str] = set() if keep_dims != 1: @@ -69,10 +67,8 @@ def _compatible_shape(self, shape: List[sympy.Expr], split_if_different: bool) - if len(shape) > len(self.target_shape): return False shape = [sympy.Integer(1)] * (len(self.target_shape) - len(shape)) + shape - for axis in range(len(shape)): - if shape[axis] != self.target_shape[axis] and ( - not shape[axis].is_number or shape[axis] != sympy.Integer(1) - ): + for axis, dim in enumerate(shape): + if dim != self.target_shape[axis] and (not dim.is_number or dim != sympy.Integer(1)): return False return True @@ -129,7 +125,7 @@ def has_reduced_elementwise_nodes(self) -> bool: return not is_reduction_node(self.nodes_groups[0]) and len(self.reduced_args) > 0 def dependent_nodes(self, keep_reduce_node: bool): - node_map = dict() + node_map = {} reduce_nodes = [] if not keep_reduce_node and self.has_reduced_elementwise_nodes(): for item in self.nodes_groups: @@ -151,8 +147,8 @@ def flatten(self, sorted_nodes: List[NodeProto]) -> Tuple[List[NodeProto], List[ layers = [] group_layer = [self] while len(group_layer) > 0: - node_map = dict() - reduce_node_map = dict() + node_map = {} + reduce_node_map = {} next_layer = [] for group in group_layer: sub_node_map, reduce_nodes = group.dependent_nodes(False) @@ -201,7 +197,7 @@ def __init__(self): self.cross_kernel_inputs: List[str] = [] self.constants: List[str] = [] self.module_outputs: List[str] = [] - self.cross_kernel_outputs: [str] = [] + self.cross_kernel_outputs: List[str] = [] self.internal_args: List[str] = [] @@ -284,7 +280,7 @@ def _process_node(self, node: NodeProto, precessors: Dict[str, List[NodeProto]], return dependent_nodes def _group_nodes(self): - producers = dict() + producers = {} precessors = defaultdict(list) processed = set() groups = [] @@ -321,13 +317,16 @@ def _group_nodes(self): group_dependencies[k].add(j) flag = set() - for i in range(len(groups)): - if i not in flag: - for j in range(i + 1, len(groups)): - if j not in flag and j not in group_dependencies[i] and groups[i].try_merge(groups[j]): - flag.add(j) - self._groups.append(groups[i]) - flag.add(i) + for i, group_i in enumerate(groups): + if i in flag: + continue + for j, group_j in enumerate(groups): + if j <= i: + continue + if j not in flag and j not in group_dependencies[i] and group_i.try_merge(group_j): + flag.add(j) + self._groups.append(group_i) + flag.add(i) def _get_node_io(self, node: NodeProto) -> Tuple[List[TensorArg], List[TensorArg]]: input_args = [] @@ -395,7 +394,7 @@ def _analyze_kernel_io_list(self): def _insert_load_and_store(self, kernel_node: KernelNode): input_names = [input.name for input in kernel_node.inputs] - output_name_map = dict() + output_name_map = {} for output in kernel_node.outputs: output_name_map[output.name] = 0 for node in kernel_node.sub_nodes: @@ -499,7 +498,7 @@ def _lower(self): warnings.warn("Use triton's random for Dropout, ignore the random seed from ORT.", UserWarning) self._analyze_kernel_io_list() - cross_kernel_arg_map = dict() + cross_kernel_arg_map = {} for idx, kernel_io in enumerate(self._kernel_io_list): for output in itertools.chain(kernel_io.cross_kernel_outputs, kernel_io.module_outputs): cross_kernel_arg_map[output] = idx diff --git a/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py b/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py index 69df567500a89..32e54d0868013 100644 --- a/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py +++ b/orttraining/orttraining/python/training/ort_triton/_sorted_graph.py @@ -5,15 +5,17 @@ import copy import itertools -from typing import Any, Dict, List, Set +from typing import Dict, List, Set import numpy as np import onnx -from onnx import GraphProto, ModelProto, NodeProto, helper +import sympy +from onnx import GraphProto, ModelProto, NodeProto, TensorProto, helper -from ._common import TensorInfo, TypeAndShapeInfer +from ._common import SymbolicDSU, TensorInfo, TypeAndShapeInfer from ._decompose import DecomposeDispatch from ._op_config import is_elementwise_node +from ._sympy_utils import parse_shape from ._utils import get_attribute, to_numpy_array, topological_sort @@ -29,17 +31,20 @@ class SortedGraph: input_shapes: the shapes of the model inputs. Can be numeric values or symbolic values. """ - def __init__(self, model: ModelProto, input_shapes: List[List[Any]]): + def __init__(self, model: ModelProto, input_shapes: List[List[sympy.Expr]]): self._model: ModelProto = model self._graph: GraphProto = model.graph - self._input_shapes: List[List[Any]] = input_shapes + self._input_shapes: List[List[sympy.Expr]] = input_shapes # For elementwise graph outputs, when we group nodes to different kernels, if the target shape is different # from other nodes' target shape, even it can be broadcasted, we still need to create a new kernel for it. self._elementwise_graph_outputs: Set[str] = set() + graph_output_names = [output.name for output in self._graph.output] for node in self._graph.node: if is_elementwise_node(node): - self._elementwise_graph_outputs.update(node.output) + self._elementwise_graph_outputs.update( + [output for output in node.output if output in graph_output_names] + ) # Topological sort the nodes in the graph. self._sorted_nodes: List[NodeProto] = topological_sort( @@ -53,7 +58,7 @@ def __init__(self, model: ModelProto, input_shapes: List[List[Any]]): for initializer in self._graph.initializer: self._node_arg_infos[initializer.name] = TensorInfo( initializer.data_type, - list(to_numpy_array(initializer).shape), + parse_shape(list(to_numpy_array(initializer).shape)), ) # Decompose complex operators. @@ -66,7 +71,7 @@ def __init__(self, model: ModelProto, input_shapes: List[List[Any]]): initializers = {} for initializer in self._graph.initializer: initializers[initializer.name] = initializer - self._sorted_initializers: List[TensorInfo] = [] + self._sorted_initializers: List[TensorProto] = [] for node in self._sorted_nodes: for input in node.input: if input in initializers: @@ -157,6 +162,7 @@ def elementwise_graph_outputs(self) -> Set[str]: def _decompose(self): dispatch = DecomposeDispatch() + symbolics: SymbolicDSU = SymbolicDSU() pos = 0 # If a node is complex, decompose it and insert the decomposed nodes at the same position. # All complex Ops are defined in DecomposeDispatch. @@ -175,16 +181,18 @@ def _decompose(self): value_attr = get_attribute(node, "value") self._node_arg_infos[node.output[0]] = TensorInfo( value_attr.data_type, - list(to_numpy_array(value_attr).shape), + parse_shape(list(to_numpy_array(value_attr).shape)), ) else: input_infos = [] for input in node.input: input_infos.append(self._node_arg_infos[input]) - output_infos = TypeAndShapeInfer.infer(node, input_infos, self._graph) + output_infos = TypeAndShapeInfer.infer(node, input_infos, self._graph, symbolics) for idx, output in enumerate(node.output): self._node_arg_infos[output] = output_infos[idx] pos += 1 + for tensor_info in self._node_arg_infos.values(): + tensor_info.update_shape(symbolics) # Save the ONNX graphs for debug purpose. The original ONNX graph is the subgraph from backend. # The processed ONNX graph is the subgraph after decompose, it also contains the concrete shapes for each arg. @@ -197,13 +205,20 @@ def save_onnx(self, file_path_prefix): for node in itertools.chain(processed_model.graph.input, processed_model.graph.output): node.type.tensor_type.shape.Clear() for dim in self.node_arg_infos[node.name].shape: - node.type.tensor_type.shape.dim.add().dim_value = int(dim) + if dim.is_number: + node.type.tensor_type.shape.dim.add().dim_value = int(dim) + else: + node.type.tensor_type.shape.dim.add().dim_param = str(dim) value_infos = [] for node in itertools.chain(self.const_nodes, self.sorted_nodes): for output in node.output: tensor_info = self.node_arg_infos[output] value_infos.append( - helper.make_tensor_value_info(output, tensor_info.dtype, [int(dim) for dim in tensor_info.shape]) + helper.make_tensor_value_info( + output, + tensor_info.dtype, + [int(dim) if dim.is_number else str(dim) for dim in tensor_info.shape], + ) ) processed_model.graph.ClearField("value_info") processed_model.graph.value_info.extend(value_infos) diff --git a/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py b/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py index e3629b5effa38..a4a384c021fe8 100644 --- a/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py +++ b/orttraining/orttraining/python/training/ort_triton/_sympy_utils.py @@ -9,6 +9,12 @@ import sympy +def extract_shape_from_symbol(symbol: str) -> int: + match = re.match(r"i(\d+)_dim(\d+)_(\d+)", symbol) + assert match + return int(match.group(3)) + + def sympy_dot(seq1: List[sympy.Expr], seq2: List[sympy.Expr]) -> sympy.Expr: assert len(seq1) == len(seq2) return sympy.expand(sum(a * b for a, b in zip(seq1, seq2))) diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py index c1b99e4859dbd..3213a8831ae22 100644 --- a/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py +++ b/orttraining/orttraining/python/training/ort_triton/kernel/__init__.py @@ -5,8 +5,10 @@ import os +import torch + from ._mm import triton_gemm, triton_gemm_out, triton_matmul, triton_matmul_out # noqa: F401 -from ._slice_scel import optimize_graph_for_slice_scel, slice_scel, slice_scel_backward # noqa: F401 +from ._slice_scel import slice_scel, slice_scel_backward # noqa: F401 _all_kernels = [ "triton_gemm", @@ -17,14 +19,14 @@ "slice_scel_backward", ] -_all_optimizers = [ - "optimize_graph_for_slice_scel", -] - -if "ORTMODULE_USE_FLASH_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1: - from ._flash_attn import flash_attn_backward, flash_attn_forward, optimize_graph_for_flash_attention # noqa: F401 +if ( + "ORTMODULE_USE_FLASH_ATTENTION" in os.environ + and int(os.getenv("ORTMODULE_USE_FLASH_ATTENTION")) == 1 + and torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 8 +): + from ._flash_attn import flash_attn_backward, flash_attn_forward # noqa: F401 _all_kernels.extend(["flash_attn_forward", "flash_attn_backward"]) - _all_optimizers.append("optimize_graph_for_flash_attention") -__all__ = _all_kernels + _all_optimizers # noqa: PLE0605 +__all__ = _all_kernels # noqa: PLE0605 diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py index b970c730d0441..1fe61750e651e 100644 --- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py @@ -8,7 +8,7 @@ import os import sys from types import ModuleType -from typing import List, Tuple +from typing import List, Tuple, Union import onnx from torch._C import _from_dlpack @@ -18,8 +18,8 @@ from ._codegen import codegen from ._op_config import get_supported_ops from ._sorted_graph import SortedGraph -from ._sympy_utils import parse_shape -from ._utils import gen_unique_name +from ._sympy_utils import extract_shape_from_symbol, parse_shape +from ._utils import gen_unique_name, next_power_of_2 _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1 @@ -31,11 +31,46 @@ def _gen_module_internal(sorted_graph: SortedGraph) -> Tuple[str, str, ModuleTyp return func_name, src_code, PyCodeCache().load(src_code) -def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[int]]) -> int: +class _ShapeCache: + """ + Cache the shapes of the inputs. The inputs are the concrete shapes of inputs from each step for a given ONNX model. + For those dimensions that the concrete shape is not changed, we use the same concrete shape. + For those dimensions that the concrete shape is changed between different steps, we use a symbolic shape. + """ + + cache = dict() # noqa: RUF012 + clear = staticmethod(cache.clear) + + @classmethod + def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]: + if onnx_key not in cls.cache: + cls.cache[onnx_key] = shapes + else: + changed = False + for i, shape in enumerate(shapes): + for j, dim in enumerate(shape): + if dim != cls.cache[onnx_key][i][j] and isinstance(cls.cache[onnx_key][i][j], int): + max_dim = max(dim, cls.cache[onnx_key][i][j]) + shape[j] = f"i{i}_dim{j}_{next_power_of_2(max_dim)}" + changed = True + elif isinstance(cls.cache[onnx_key][i][j], str): + pre = extract_shape_from_symbol(cls.cache[onnx_key][i][j]) + if pre >= dim: + shape[j] = cls.cache[onnx_key][i][j] + else: + shape[j] = f"i{i}_dim{j}_{next_power_of_2(dim)}" + changed = True + if changed: + cls.cache[onnx_key] = shapes + return cls.cache[onnx_key] + + +def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int: + # pylint: disable=unused-argument return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") % (10**8) -def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[int]]) -> Tuple[str, ModuleType]: +def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: model = onnx.load_model_from_string(onnx_str) sorted_graph = SortedGraph(model, [parse_shape(shape) for shape in shapes]) if _DEBUG_MODE: @@ -44,7 +79,7 @@ def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[int]]) -> Tupl func_name, src_code, mod = _gen_module_internal(sorted_graph) if _DEBUG_MODE: py_file_path = f"triton_debug/{func_name}_{onnx_key}.py" - with open(py_file_path, "w") as f: + with open(py_file_path, "w", encoding="UTF-8") as f: f.write(src_code) return func_name, mod @@ -52,7 +87,8 @@ def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[int]]) -> Tupl def get_config() -> str: """ Get the supported ops and other configs in JSON format to control the Triton fusion on backend side. - All supported ops are from _op_config.py. The Triton fusion will try to fuse subgraphs with connected supported ops. + All supported ops are from user config specified by env ORTMODULE_TRITON_CONFIG_FILE or from _op_config.py. + The Triton fusion will try to fuse subgraphs with connected supported ops. The initializer value can be "none", "scalar", and "all". "none": no initializer will be added to subgraphs. "scalar": only related scalar initializers will be added to subgraphs. @@ -60,6 +96,11 @@ def get_config() -> str: The min_nodes is used to control the minimum number of non-no-op nodes in a subgraph. """ + config_file = os.getenv("ORTMODULE_TRITON_CONFIG_FILE", "") + if config_file and os.path.exists(config_file): + with open(config_file, encoding="UTF-8") as f: + return f.read() + config = {"ops": get_supported_ops(), "initializer": "scalar", "min_nodes": 2} return json.dumps(config) @@ -90,7 +131,8 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors): assert all(tensor is not None for tensor in tensors) torch_tensors = [_from_dlpack(tensor) for tensor in tensors] concrete_shapes = [list(tensor.size()) for tensor in torch_tensors] - func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, onnx_str, concrete_shapes) + shapes = _ShapeCache.get_shape(onnx_key, concrete_shapes) + func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, onnx_str, shapes) func = getattr(mod, func_name) output = func(*torch_tensors) if isinstance(output, tuple): diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 4977272de5ac9..8efbe16d7d61d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -412,14 +412,24 @@ def _matmul4bit_export(g, n, *args, **kwargs): return None quant_state = args[4] - absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state + if isinstance(quant_state, list): + # version <= 0.41.1 + absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state + nested = compressed_stats is not None + else: + # version > 0.41.1 + absmax = quant_state.absmax + shape = quant_state.shape + blocksize = quant_state.blocksize + nested = quant_state.nested + quant_type = quant_state.quant_type # MatMulBnb4's blocksize needs to be a power of 2 and not smaller than 16 if blocksize < 16 or blocksize & (blocksize - 1) != 0: return None # MatMulBnb4 does not support double de-quantization (e.g. absmax is int, needs to be dequantized too) - if compressed_stats is not None: + if nested: return None # The PyTorch linear weight shape is [out_feature, in_feature] diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py index d215e12f8137a..3d3538a62da61 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py @@ -5,9 +5,16 @@ import os +import torch +from packaging.version import Version + _all_optimizers = [] -if "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1: +if ( + "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ + and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1 + and Version(torch.__version__) >= Version("2.1.1") +): from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401 _all_optimizers.append("optimize_graph_for_aten_efficient_attention") diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py index 94bd41293b427..b1e8809f03fc0 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py @@ -245,31 +245,25 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro ("MatMul", False, []), # 0 ("Mul", True, [(0, 0, 0)]), # 1 ("Mul", True, [(0, 0, 1)]), # 2 - ("Cast", True, [(1, 0, 0)]), # 3 - ("Cast", True, [(2, 0, 0)]), # 4 - ("Transpose", True, [(3, 0, 0)]), # 5 - ("Transpose", True, [(4, 0, 0)]), # 6 - ("Softmax", False, [(0, 0, 0)]), # 7 - ("Cast", False, [(7, 0, 0)]), # 8 - ("MatMul", False, [(8, 0, 0)]), # 9 - ("Transpose", True, [(9, 0, 1)]), # 10 - ("Transpose", False, [(9, 0, 0)]), # 11 - ("FusedMatMul", False, [(10, 0, 1)]), # 12 - ("Cast", False, [(12, 0, 0)]), # 13 - ("SoftmaxGrad_13", False, [(13, 0, 0), (7, 0, 1)]), # 14 - ("FusedMatMul", False, [(2, 0, 1), (14, 0, 0)]), # 15 - ("FusedMatMul", False, [(1, 0, 0), (14, 0, 1)]), # 16 - ("Mul", False, [(15, 0, 0)]), # 17 - ("Mul", False, [(16, 0, 0)]), # 18 - ("Identity", False, [(17, 0, 0)]), # 19 - ("Identity", False, [(18, 0, 0)]), # 20 - ("Cast", False, [(19, 0, 0)]), # 21 - ("Cast", False, [(20, 0, 0)]), # 22 - ("Transpose", False, [(21, 0, 0)]), # 23 - ("Transpose", False, [(22, 0, 0)]), # 24 - ("FusedMatMul", False, [(8, 0, 0)]), # 25 - ("Transpose", True, [(25, 0, 1)]), # 26 - ("Transpose", False, [(25, 0, 0)]), # 27 + ("Transpose", True, [(1, 0, 0)]), # 3 + ("Transpose", True, [(2, 0, 0)]), # 4 + ("Softmax", False, [(0, 0, 0)]), # 5 + ("MatMul", False, [(5, 0, 0)]), # 6 + ("Transpose", True, [(6, 0, 1)]), # 7 + ("Transpose", False, [(6, 0, 0)]), # 8 + ("FusedMatMul", False, [(7, 0, 1)]), # 9 + ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 + ("FusedMatMul", False, [(2, 0, 1), (10, 0, 0)]), # 11 + ("FusedMatMul", False, [(1, 0, 0), (10, 0, 1)]), # 12 + ("Mul", False, [(11, 0, 0)]), # 13 + ("Mul", False, [(12, 0, 0)]), # 14 + ("Identity", False, [(13, 0, 0)]), # 15 + ("Identity", False, [(14, 0, 0)]), # 16 + ("Transpose", False, [(15, 0, 0)]), # 17 + ("Transpose", False, [(16, 0, 0)]), # 18 + ("FusedMatMul", False, [(5, 0, 0)]), # 19 + ("Transpose", True, [(19, 0, 1)]), # 20 + ("Transpose", False, [(19, 0, 0)]), # 21 ] @@ -280,27 +274,24 @@ def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodePro scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 if not ( - check_attribute_value(nodes[3], "to", 1) - and check_attribute_value(nodes[4], "to", 1) - and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[8], "to", 10) - and check_attribute_value(nodes[10], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[11], "perm", [0, 2, 1, 3]) + check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) and scale_value_1 == scale_value_2 ): return [], [], [] nodes_to_add, new_value_infos = _make_efficient_attention_nodes( idx, - nodes[5].input[0], - nodes[6].input[0], - nodes[10].input[0], - nodes[11].output[0], - nodes[26].input[0], - nodes[23].output[0], - nodes[24].output[0], - nodes[27].output[0], + nodes[3].input[0], + nodes[4].input[0], + nodes[7].input[0], + nodes[8].output[0], + nodes[20].input[0], + nodes[17].output[0], + nodes[18].output[0], + nodes[21].output[0], "", False, scale_value_1, @@ -315,39 +306,32 @@ def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodePro ("MatMul", False, []), # 0 ("Mul", True, [(0, 0, 0)]), # 1 ("Mul", True, [(0, 0, 1)]), # 2 - ("Cast", True, [(1, 0, 0)]), # 3 - ("Cast", True, [(2, 0, 0)]), # 4 - ("Transpose", True, [(3, 0, 0)]), # 5 - ("Transpose", True, [(4, 0, 0)]), # 6 - ("Add", False, [(0, 0, 0)]), # 7 - ("Cast", True, [(7, 0, 1)]), # 8 - ("Slice", True, [(8, 0, 0)]), # 9 - ("Slice", True, [(9, 0, 0)]), # 10 - ("Unsqueeze", True, [(9, 0, 2)]), # 11 - ("Gather", True, [(11, 0, 0)]), # 12 - ("Shape", True, [(12, 0, 0)]), # 13 - ("Softmax", False, [(7, 0, 0)]), # 14 - ("Cast", False, [(14, 0, 0)]), # 15 - ("MatMul", False, [(15, 0, 0)]), # 16 - ("Transpose", True, [(16, 0, 1)]), # 17 - ("Transpose", False, [(16, 0, 0)]), # 18 - ("FusedMatMul", False, [(17, 0, 1)]), # 19 - ("Cast", False, [(19, 0, 0)]), # 20 - ("SoftmaxGrad_13", False, [(20, 0, 0), (14, 0, 1)]), # 21 - ("Identity", False, [(21, 0, 0)]), # 22 - ("FusedMatMul", False, [(2, 0, 1), (22, 0, 0)]), # 23 - ("FusedMatMul", False, [(1, 0, 0), (22, 0, 1)]), # 24 - ("Mul", False, [(23, 0, 0)]), # 25 - ("Mul", False, [(24, 0, 0)]), # 26 - ("Identity", False, [(25, 0, 0)]), # 27 - ("Identity", False, [(26, 0, 0)]), # 28 - ("Cast", False, [(27, 0, 0)]), # 29 - ("Cast", False, [(28, 0, 0)]), # 30 - ("Transpose", False, [(29, 0, 0)]), # 31 - ("Transpose", False, [(30, 0, 0)]), # 32 - ("FusedMatMul", False, [(15, 0, 0)]), # 33 - ("Transpose", True, [(33, 0, 1)]), # 34 - ("Transpose", False, [(33, 0, 0)]), # 35 + ("Transpose", True, [(1, 0, 0)]), # 3 + ("Transpose", True, [(2, 0, 0)]), # 4 + ("Add", False, [(0, 0, 0)]), # 5 + ("Slice", True, [(5, 0, 1)]), # 6 + ("Slice", True, [(6, 0, 0)]), # 7 + ("Unsqueeze", True, [(6, 0, 2)]), # 8 + ("Gather", True, [(8, 0, 0)]), # 9 + ("Shape", True, [(9, 0, 0)]), # 10 + ("Softmax", False, [(5, 0, 0)]), # 11 + ("MatMul", False, [(11, 0, 0)]), # 12 + ("Transpose", True, [(12, 0, 1)]), # 13 + ("Transpose", False, [(12, 0, 0)]), # 14 + ("FusedMatMul", False, [(13, 0, 1)]), # 15 + ("SoftmaxGrad_13", False, [(15, 0, 0), (11, 0, 1)]), # 16 + ("Identity", False, [(16, 0, 0)]), # 17 + ("FusedMatMul", False, [(2, 0, 1), (17, 0, 0)]), # 18 + ("FusedMatMul", False, [(1, 0, 0), (17, 0, 1)]), # 19 + ("Mul", False, [(18, 0, 0)]), # 20 + ("Mul", False, [(19, 0, 0)]), # 21 + ("Identity", False, [(20, 0, 0)]), # 22 + ("Identity", False, [(21, 0, 0)]), # 23 + ("Transpose", False, [(22, 0, 0)]), # 24 + ("Transpose", False, [(23, 0, 0)]), # 25 + ("FusedMatMul", False, [(11, 0, 0)]), # 26 + ("Transpose", True, [(26, 0, 1)]), # 27 + ("Transpose", False, [(26, 0, 0)]), # 28 ] @@ -358,27 +342,24 @@ def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodePro scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 if not ( - check_attribute_value(nodes[3], "to", 1) - and check_attribute_value(nodes[4], "to", 1) - and check_attribute_value(nodes[5], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[6], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[15], "to", 10) - and check_attribute_value(nodes[17], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[18], "perm", [0, 2, 1, 3]) + check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) + and check_attribute_value(nodes[13], "perm", [0, 2, 1, 3]) + and check_attribute_value(nodes[14], "perm", [0, 2, 1, 3]) and scale_value_1 == scale_value_2 ): return [], [], [] nodes_to_add, new_value_infos = _make_efficient_attention_nodes( idx, - nodes[5].input[0], - nodes[6].input[0], - nodes[17].input[0], - nodes[18].output[0], - nodes[34].input[0], - nodes[31].output[0], - nodes[32].output[0], - nodes[35].output[0], + nodes[3].input[0], + nodes[4].input[0], + nodes[13].input[0], + nodes[14].output[0], + nodes[27].input[0], + nodes[24].output[0], + nodes[25].output[0], + nodes[28].output[0], "", False, scale_value_1, diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py deleted file mode 100644 index d5a488c436a1d..0000000000000 --- a/orttraining/orttraining/python/training/orttrainer.py +++ /dev/null @@ -1,1537 +0,0 @@ -import copy -import io -import os -import warnings -from functools import partial -from inspect import signature - -import numpy as np -import onnx -import torch - -import onnxruntime as ort -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - -from . import _checkpoint_storage, _utils, amp, checkpoint, optim, postprocess -from .model_desc_validation import _ORTTrainerModelDesc -from .orttrainer_options import ORTTrainerOptions - - -class TrainStepInfo: - r"""Private class used to store runtime information from current train step. - - After every train step, :py:meth:`ORTTrainer.train_step` updates the internal instance of - :py:class:`.TrainStepInfo` residing on :py:class:`.ORTTrainer` with relevant information - from the forward pass. - - This class shouldn't be accessed directly by the user, unless they really know what they are doing. - Instead, :py:class:`.ORTTrainer` passes it to relevant class methods automatically, - such as :py:method:`._LRScheduler.get_lr` or :py:class:`.LossScaler.update`. - - Args: - optimizer_config (optim._OptimizerConfig): reference to optimizer config - all_finite (bool, default is True): flag that indicates whether all gradients are still finite after last step - fetches (list of str, default is []): list of output names to fetch from train_step/eval_step. Set it to [] to reset normal behavior. - optimization_step (int): indicates the number of optimizations performed. Used for learning rate scheduling - step (int): indicates current training step. Used for gradient accumulation - - Example: - - .. code-block:: python - - info = TrainStepInfo(optimizer_config=optim.SGDConfig(lr=0.01)) - if info.all_finite: - print(f'Yay, all gradients are finite at {step} step!') - - """ - - def __init__(self, optimizer_config, all_finite=True, fetches=[], optimization_step=0, step=0): # noqa: B006 - assert isinstance(optimizer_config, optim._OptimizerConfig), "optimizer_config must be a optim._OptimizerConfig" - assert isinstance(all_finite, bool), "all_finite must be a bool" - assert isinstance(fetches, list) and all( - [isinstance(item, str) for item in fetches] - ), "fetches must be a list of str" - assert isinstance(optimization_step, int) and optimization_step >= 0, "optimization_step must be a positive int" - assert isinstance(step, int) and step >= 0, "step must be a positive int" - - self.optimizer_config = optimizer_config - self.all_finite = all_finite - self.fetches = fetches - self.optimization_step = optimization_step - self.step = step - - -class ORTTrainer: - r"""Pytorch frontend for ONNX Runtime training - - Entry point that exposes the C++ backend of ORT as a Pytorch frontend. - - Args: - model (torch.nn.Module or onnx.ModelProto): either a PyTorch or ONNX model. - When a PyTorch model and :py:attr:`loss_fn` are specified, :py:attr:`model` and :py:obj:`loss_fn` are combined. - When a ONNX model is provided, the loss is identified by the flag :py:obj:`is_loss=True` in one of the :py:attr:`.model_desc.outputs` entries. - model_desc (dict): model input and output description. - This is used to identify inputs and outputs and their shapes, so that ORT can generate back propagation graph, plan memory allocation for - training, and perform optimizations. - :py:attr:`model_desc` must be consistent with the training :py:attr:`model` and have the following (:py:obj:`dict`) schema - :py:obj:`{ 'inputs': [tuple(name, shape)], 'outputs': [tuple(name, shape, is_loss)]}`. - :py:attr:`name` is a string representing the name of input or output of the model. - For :py:obj:`model_desc['inputs']` entries, :py:attr:`name` must match input names of the original PyTorch model's :py:meth:`torch.nn.Module.forward` method. - For ONNX models, both name and order of input names must match. - For :py:obj:`model_desc['outputs']` entries, the order must match the original PyTorch's output as returned by :py:meth:`torch.nn.Module.forward` method. - For ONNX models, both name and order of output names must match. - :py:attr:`shape` is a list of string or integers that describes the shape of the input/output. - Each dimension size can be either a string or an int. String means the dimension size is dynamic, while integers mean static dimensions. - An empty list implies a scalar. - Lastly, :py:attr:`is_loss` is a boolean (default is False) that flags if this output is considered a loss. - ORT backend needs to know which output is loss in order to generate back propagation graph. - Loss output must be specified when either :py:attr:`loss_fn` is specified or when loss is embedded in the model. - Note that only one loss output is supported per model. - optimizer_config (optim._OptimizerConfig): optimizer config. - One of :py:class:`.optim.AdamConfig`, :py:class:`.optim.LambConfig` or :py:class:`.optim.SGDConfig`. - loss_fn (callable, default is None): a PyTorch loss function. - It takes two inputs [prediction, label] and outputs a scalar loss tensor. - If provided, :py:attr:`loss_fn` is combined with the PyTorch :py:attr:`model` to form a combined PyTorch model. - Inputs to the combined PyTorch model are concatenation of the :py:attr:`model`'s input and :py:attr:`loss_fn`'s label input. - Outputs of the combined PyTorch model are concatenation of :py:attr:`loss_fn`'s loss output and :py:attr:`model`'s outputs. - options (ORTTrainerOptions, default is None): options for additional features. - Example: - - .. code-block:: python - - model = ... - loss_fn = ... - model_desc = { - "inputs": [ - ("input_ids", ["batch", "max_seq_len_in_batch"]), - ("attention_mask", ["batch", "max_seq_len_in_batch"]), - ("token_type_ids", ["batch", "max_seq_len_in_batch"]), - ("masked_lm_labels", ["batch", "max_seq_len_in_batch"]), - ("next_sentence_label", ["batch", 1]) - ], - "outputs": [ - ("loss", [], True), - ], - } - optim_config = optim.LambConfig(param_groups = [ { 'params' : ['model_param0'], 'alpha' : 0.8, 'beta' : 0.7}, - { 'params' : ['model_param1' , 'model_param_2'], 'alpha' : 0.0} - ], - alpha=0.9, beta=0.999) - ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn) - """ - - def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): - warnings.warn( - "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", - FutureWarning, - ) - - assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model" - assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'" - assert isinstance( - optim_config, optim._OptimizerConfig - ), "'optim_config' is required and must be any of 'AdamConfig', 'LambConfig' or 'SGDConfig'" - assert loss_fn is None or ( - callable(loss_fn) and len(signature(loss_fn).parameters) == 2 - ), "'loss_fn' must be either 'None' or a callable with two parameters" - assert options is None or isinstance( - options, ORTTrainerOptions - ), "'options' must be either 'None' or 'ORTTrainerOptions'" - - # Model + Loss validation - # Supported combinarios are - # ---------------------------------------- - # | | Model | Loss | - # ---------------------------------------- - # | 1 | torch.nn.Module | None | - # | 2 | torch.nn.Module | torch.nn.Module | - # | 3 | ONNX | None | - # ---------------------------------------- - self._torch_model = None - self._onnx_model = None - if isinstance(model, torch.nn.Module): - assert loss_fn is None or isinstance( - model, torch.nn.Module - ), "'loss_fn' must be either 'None' or 'torch.nn.Module'" - self._torch_model = model - self.loss_fn = loss_fn - # TODO: Remove when experimental checkpoint functions are removed. - self._torch_state_dict_keys = list(model.state_dict().keys()) - elif isinstance(model, onnx.ModelProto): - assert loss_fn is None, "'loss_fn' must not be specified when 'model' is an ONNX model" - self._onnx_model = model - self.loss_fn = None - else: - raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'") - - self.model_desc = _ORTTrainerModelDesc(model_desc) - self.optim_config = optim_config - - # ORTTrainerOptions - if not options: - options = ORTTrainerOptions() - self.options = options - if self.options.mixed_precision.enabled and not self.options.mixed_precision.loss_scaler: - # TODO: Move this to model_desc_validation.py - self.options.mixed_precision.loss_scaler = amp.loss_scaler.DynamicLossScaler() - # Post processing ONNX model given as input - if self._onnx_model: - if self.options._internal_use.enable_internal_postprocess: - self._onnx_model = postprocess.run_postprocess(self._onnx_model) - if self.options._internal_use.extra_postprocess: - self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model) - assert isinstance(self._onnx_model, onnx.ModelProto), "'extra_postprocess' must return a ONNX model" - - # When input model is already ONNX (and not exported from Pytorch within ORTTrainer), - # append 'dtype' from ONNX into model description's - for idx_i, i_desc in enumerate(self.model_desc.inputs): - dtype = None - for onnx_input in self._onnx_model.graph.input: - if onnx_input.name == i_desc.name: - dtype = _utils.dtype_onnx_to_torch(onnx_input.type.tensor_type.elem_type) - self.model_desc.add_type_to_input_description(idx_i, dtype) - break - assert dtype is not None, f"ONNX model with unknown input type ({i_desc.name})" - for idx_o, o_desc in enumerate(self.model_desc.outputs): - dtype = None - for onnx_output in self._onnx_model.graph.output: - if onnx_output.name == o_desc.name: - dtype = _utils.dtype_onnx_to_torch(onnx_output.type.tensor_type.elem_type) - self.model_desc.add_type_to_output_description(idx_o, dtype) - break - assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})" - - try: - from torch.utils.cpp_extension import ROCM_HOME - - self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) - except ImportError: - self.is_rocm_pytorch = False - - # TODO: Remove when experimental checkpoint functions are removed. - self._state_dict = {} - - self._train_step_info = TrainStepInfo(self.optim_config) - self._training_session = None - self._load_state_dict = None - self._init_session( - provider_options=self.options._validated_opts["provider_options"], - session_options=self.options.session_options, - ) - - def eval_step(self, *args, **kwargs): - r"""Evaluation step method - - Args: - *args: Arbitrary arguments that are used as model input (data only) - **kwargs: Arbitrary keyword arguments that are used as model input (data only) - - Returns: - ordered :py:obj:`list` with model outputs as described by :py:attr:`.ORTTrainer.model_desc` - """ - # Get data. CombineTorchModelLossFn takes label as last input and outputs loss first - sample_input = self._prepare_model_input(self.model_desc.inputs, None, None, *args, **kwargs) - - # Export model to ONNX - if self._onnx_model is None: - if self._torch_model is not None: - self._init_onnx_model(sample_input) - else: - raise RuntimeError("Model is uninitialized. Only ONNX and PyTorch models are supported") - - # Prepare input/output description - inputs_desc = self.model_desc.inputs - outputs_desc = self.model_desc.outputs - if self._train_step_info.fetches: - outputs_desc = [o_desc for o_desc in outputs_desc if o_desc.name in self._train_step_info.fetches] - if len(outputs_desc) != len(self._train_step_info.fetches): - raise RuntimeError("The specified fetches list contains invalid output names") - - # Normalize input - if not isinstance(sample_input, (list, tuple)): - sample_input = (sample_input,) - - # RunOptions - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - run_options.training_mode = False - - # Run a eval step and return - session_run_results = self._training_session_run_helper( - False, sample_input, inputs_desc, outputs_desc, run_options - ) - - # Output must be returned in the same order as defined in the model description - results = [session_run_results[o_desc.name] for o_desc in outputs_desc] - return results[0] if len(results) == 1 else results - - def save_as_onnx(self, path): - r"""Persists ONNX model into :py:attr:`path` - - The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard. - The graph includes full information, including inference and training metadata. - - Args: - path (str): Full path, including filename, to save the ONNX model in the filesystem - - Raises: - RuntimeWarning: raised when neither `train_step` or `eval_step` was called at least once - ValueError: raised when `path` is not valid path - """ - if not self._training_session: - warnings.warn( - "Training session is not initialized yet. " - "'train_step' or 'eval_step' methods must be executed at least once before calling 'save_as_onnx()'." - ) - return - state_tensors = self._training_session.get_state() - self._update_onnx_model_initializers(state_tensors) - - assert isinstance(path, str), "'path' must be a valid path string" - dir_name = os.path.dirname(path) - file_name = os.path.basename(path) - if (dir_name and not os.path.exists(dir_name)) or not file_name: - warnings.warn("'path' is not valid or does not exist") - return - - with open(path, "wb") as f: - f.write(self._onnx_model.SerializeToString()) - - def _check_model_export(self, input): - from numpy.testing import assert_allclose - from onnx import TensorProto, helper, numpy_helper # noqa: F401 - - onnx_model_copy = copy.deepcopy(self._onnx_model) - - # Mute the dropout nodes - dropout_nodes = [n for n in onnx_model_copy.graph.node if n.op_type == "Dropout"] - for node in dropout_nodes: - ratio_node = next(n for n in onnx_model_copy.graph.node if node.input[1] in n.output) - training_mode_node = next(n for n in onnx_model_copy.graph.node if node.input[2] in n.output) - - training_mode_node.attribute.pop() - ratio_node.attribute.pop() - new_training_mode_arr = np.array(False, dtype=bool) - new_ratio_arr = np.array(0.0, dtype=np.float32) - new_training_mode = numpy_helper.from_array(new_training_mode_arr) - new_ratio = numpy_helper.from_array(new_ratio_arr) - training_mode_node.attribute.add().t.CopyFrom(new_training_mode) - ratio_node.attribute.add().t.CopyFrom(new_ratio) - training_mode_node.attribute[0].type = 4 - ratio_node.attribute[0].type = 4 - training_mode_node.attribute[0].name = "value" - ratio_node.attribute[0].name = "value" - - _inference_sess = ort.InferenceSession( - onnx_model_copy.SerializeToString(), providers=ort.get_available_providers() - ) - inf_inputs = {} - for i, input_elem in enumerate(input): - inf_inputs[_inference_sess.get_inputs()[i].name] = input_elem.cpu().numpy() - _inference_outs = _inference_sess.run(None, inf_inputs) - for torch_item, ort_item in zip(self.torch_sample_outputs, _inference_outs): - assert_allclose( - torch_item, - ort_item, - rtol=1e-2, - atol=1e-6, - err_msg="Mismatch between outputs of PyTorch model and exported ONNX model. " - "Note that different backends may exhibit small computational differences." - "If this is within acceptable margin, or if there is random generator " - "in the model causing inevitable mismatch, you can proceed training by " - "setting the flag debug.check_model_export to False.", - ) - - def train_step(self, *args, **kwargs): - r"""Train step method - - After forward pass, an ordered list with all outputs described at :py:attr:`ORTTrainer.model_desc` is returned. - Additional information relevant to the train step is maintend by :py:attr:`ORTTrainer._train_step_info`. - See :py:class:`.TrainStepInfo` for details. - - Args: - *args: Arbitrary arguments that are used as model input (data only) - **kwargs: Arbitrary keyword arguments that are used as model input (data only) - - Returns: - ordered :py:obj:`list` with model outputs as described by :py:attr:`ORTTrainer.model_desc` - """ - # Export model to ONNX - if self._onnx_model is None: - sample_input = self._prepare_model_input(self.model_desc.inputs, None, None, *args, **kwargs) - self._init_onnx_model(sample_input) - - # Debug Model Export if indicated - if self.options.debug.check_model_export: - self._check_model_export(sample_input) - - # Prepare inputs+lr and output descriptions - inputs_desc = self._model_desc_inputs_with_lr - outputs_desc = self.model_desc.outputs - - # Train step must be incremented *before* gradient accumulation code - # Gradients are accumulated when - # self._train_step_info.step % self.options.batch.gradient_accumulation_steps != 0, - # and they are updated otherwise - self._train_step_info.step += 1 - - # RunOptions - run_options = None - mixed_precision_without_fetches = False - if self._train_step_info.fetches: - outputs_desc = [o_desc for o_desc in outputs_desc if o_desc.name in self._train_step_info.fetches] - if len(outputs_desc) != len(self._train_step_info.fetches): - raise RuntimeError("The specified fetches list contains invalid output names") - elif self._train_step_info.step % self.options.batch.gradient_accumulation_steps != 0: - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - outputs_desc = self._model_desc_outputs_with_gradient_accumulation - elif self.options.mixed_precision.enabled: - mixed_precision_without_fetches = True - outputs_desc = self._model_desc_outputs_with_all_finite - - # Update Learning Rate if Necessary - lr = self.optim_config.lr - if self.options.lr_scheduler: - lr = self.options.lr_scheduler._step(self._train_step_info)[0] - - # Loss Scale for mixed precision - loss_scale = None - if self.options.mixed_precision.enabled: - loss_scaler = self.options.mixed_precision.loss_scaler - assert loss_scaler, "Loss scaler is required when mixed precision is enabled" - loss_scale = loss_scaler.loss_scale - inputs_desc = self._model_desc_inputs_with_lr_and_loss_scale - - # Get data. CombineTorchModelLossFn takes label as last input and outputs loss first - input = self._prepare_model_input(inputs_desc, lr, loss_scale, *args, **kwargs) - - # Normalize input - if not isinstance(args, (list, tuple)): - args = (args,) - - # Run a train step and return - session_run_results = self._training_session_run_helper(True, input, inputs_desc, outputs_desc, run_options) - if mixed_precision_without_fetches: - # After session run with all_fp32_gradients_finite, we need to clear the training I/O binding's output - # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce - # because all_fp32_gradients_finite is still in the feed. - self._train_io_binding.clear_binding_outputs() - - is_all_finite = session_run_results[self.model_desc.all_finite.name] - self._train_step_info.all_finite = is_all_finite - if loss_scaler: - loss_scaler.update(self._train_step_info) - if is_all_finite: - # Optimization step must be incremented *after* optimization is successful - self._train_step_info.optimization_step += 1 - elif self._train_step_info.step % self.options.batch.gradient_accumulation_steps == 0: - # Optimization step must be incremented *after* optimization is successful - self._train_step_info.optimization_step += 1 - - # Output must be returned in the same order as defined in the model description - # or in the order specified by TrainStepInfo.fetches, if applicable - if self._train_step_info.fetches: - results = [session_run_results[o_desc] for o_desc in self._train_step_info.fetches] - else: - results = [session_run_results[o_desc.name] for o_desc in self.model_desc.outputs] - return results[0] if len(results) == 1 else results - - def _convert_torch_model_loss_fn_to_onnx(self, inputs, device): - # Dynamic axes - dynamic_axes = {} - for input in self.model_desc.inputs: - symbolic_axis = {} - for i, axis in enumerate(input.shape): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[input.name] = symbolic_axis - for output in self.model_desc.outputs: - symbolic_axis = {} - for i, axis in enumerate(output.shape): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[output.name] = symbolic_axis - - if isinstance(inputs, torch.Tensor): - inputs = [inputs] - if isinstance(inputs, dict): - sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs] - elif isinstance(inputs, (list, tuple)): - sample_inputs = [ - input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs) - ] - else: - raise RuntimeError( - "Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported." - ) - - # PyTorch ONNX exporter does not match argument names - # This is an issue because the ONNX graph depends on all inputs to be specified - - # Validate loss_fn - if self.loss_fn: - sig_loss = signature(self.loss_fn) - if len(sig_loss.parameters) != 2: - raise RuntimeError("loss function should take two arguments - predict and label.") - - # Basic input names from model - input_names = [input.name for input in self.model_desc.inputs] - sig = signature(self._torch_model.forward) - ordered_input_list = list(sig.parameters.keys()) - - # Label from loss_fn goes after model input - if self.loss_fn: - ordered_input_list = [*ordered_input_list, list(sig_loss.parameters.keys())[1]] - - class CombineTorchModelLossFnWrapInput(torch.nn.Module): - def __init__(self, model, loss_fn, input_names): - super().__init__() - self.model = model - self.loss_fn = loss_fn - self.input_names = input_names - - def forward(self, *inputs): - sig = signature(self.model.forward) - - input_dict = {} - for key in sig.parameters: - if key in self.input_names: - input_dict[key] = inputs[self.input_names.index(key)] - - model_out = self.model(**input_dict) - if self.loss_fn is None: - return model_out - - label = inputs[-1] - preds = model_out - return self.loss_fn(preds, label), preds - - model = CombineTorchModelLossFnWrapInput(self._torch_model, self.loss_fn, input_names) - - # Do an inference to grab output types - model.eval() - with torch.no_grad(): - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - try: - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(model) - except Exception: - model_copy = model - warnings.warn( - "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." - " Compute will continue, but unexpected results may occur!" - ) - sample_outputs = model_copy(*sample_inputs_copy) - self.torch_sample_outputs = sample_outputs - model.train() - - if isinstance(sample_outputs, torch.Tensor): - sample_outputs = [sample_outputs] - - # Append 'dtype' for model description's inputs/outputs - for idx_i, sample_input in enumerate(sample_inputs): - if idx_i < len(self.model_desc.inputs): - self.model_desc.add_type_to_input_description(idx_i, sample_input.dtype) - for idx_o, sample_output in enumerate(sample_outputs): - if idx_o < len(self.model_desc.outputs): - self.model_desc.add_type_to_output_description(idx_o, sample_output.dtype) - - # Export the model to ONNX - f = io.BytesIO() - - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - - # Handle contrib OPs support - from onnxruntime.tools import pytorch_export_contrib_ops - - if self.options._internal_use.enable_onnx_contrib_ops: - pytorch_export_contrib_ops.register() - else: - # Unregister in case they were registered in previous calls. - pytorch_export_contrib_ops.unregister() - - # Export torch.nn.Module to ONNX - torch.onnx.export( - model, - tuple(sample_inputs_copy), - f, - input_names=[input.name for input in self.model_desc.inputs], - output_names=[output.name for output in self.model_desc.outputs], - opset_version=self.options._internal_use.onnx_opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False, - training=torch.onnx.TrainingMode.TRAINING, - ) - onnx_model = onnx.load_model_from_string(f.getvalue()) - - # Remove 'model.' prefix introduced by CombineTorchModelLossFn class - if isinstance(model, CombineTorchModelLossFnWrapInput): - replace_name_dict = {} - for n in onnx_model.graph.initializer: - if n.name.startswith("model."): - replace_name_dict[n.name] = n.name[len("model.") :] - n.name = replace_name_dict[n.name] - for n in onnx_model.graph.node: - for i, name in enumerate(n.input): - if name in replace_name_dict: - n.input[i] = replace_name_dict[name] - - return onnx_model - - def _create_ort_training_session(self, optimizer_state_dict=None, session_options=None, provider_options=None): - if optimizer_state_dict is None: - optimizer_state_dict = {} - # Validating frozen_weights names - unused_frozen_weights = [ - n - for n in self.options.utils.frozen_weights - if n not in [i.name for i in self._onnx_model.graph.initializer] - ] - if unused_frozen_weights: - raise RuntimeError(f"{unused_frozen_weights} params from 'frozen_weights' not found in the ONNX model.") - - # Get loss name from model description - loss_name = [item.name for item in self.model_desc.outputs if item.is_loss] - assert len(loss_name) == 1, f"Only one loss output is supported ({len(loss_name)} were specified)" - loss_name = loss_name[0] - - # Parse optimizer parameters - optimizer_attributes_map = {} - optimizer_int_attributes_map = {} - trainable_params = set() - for initializer in self._onnx_model.graph.initializer: - if initializer.name in self.options.utils.frozen_weights: - continue # only trainable parameters are passed to the backend - trainable_params.add(initializer.name) - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - not_in_param_groups = True - for param_group in self.optim_config.params: - if initializer.name not in param_group["params"]: - continue # keep looking for a matching param_group - not_in_param_groups = False - for k, v in param_group.items(): - # 'params' is not a hyper parameter, skip it. 'lr' per weight is not supported - if k == "params" or k == "lr": - continue - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - - # set default values for params not found in groups - if not_in_param_groups: - for k, v in self.optim_config.defaults.items(): - if k == "lr": - continue - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - - self.options.distributed.horizontal_parallel_size = max(self.options.distributed.horizontal_parallel_size, 1) - self.options.distributed.data_parallel_size = ( - self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size - ) - - # TrainingParameters - ort_parameters = ort.TrainingParameters() - ort_parameters.loss_output_name = loss_name - ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled - ort_parameters.world_rank = self.options.distributed.world_rank - ort_parameters.world_size = self.options.distributed.world_size - ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps - ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation - ort_parameters.enable_adasum = self.options.distributed.enable_adasum - ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage - ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip - ort_parameters.set_gradients_as_graph_outputs = False - ort_parameters.use_memory_efficient_gradient = self.options.utils.memory_efficient_gradient - ort_parameters.training_optimizer_name = self.optim_config.name - ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name - ort_parameters.weights_to_train = trainable_params - ort_parameters.optimizer_attributes_map = optimizer_attributes_map - ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map - if bool(optimizer_state_dict): - ort_parameters.set_optimizer_initial_state(optimizer_state_dict) - - ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute - ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute - ort_parameters.transformer_layer_recompute = self.options.graph_transformer.transformer_layer_recompute - ort_parameters.number_recompute_layers = self.options.graph_transformer.number_recompute_layers - - ort_parameters.data_parallel_size = self.options.distributed.data_parallel_size - ort_parameters.horizontal_parallel_size = self.options.distributed.horizontal_parallel_size - ort_parameters.pipeline_parallel_size = self.options.distributed.pipeline_parallel.pipeline_parallel_size - ort_parameters.num_pipeline_micro_batches = ( - self.options.distributed.pipeline_parallel.num_pipeline_micro_batches - ) - ort_parameters.pipeline_cut_info_string = self.options.distributed.pipeline_parallel.pipeline_cut_info_string - # We have special handling for dictionary-typed option. - # sliced_schema._validated_opts is the original dictionary while sliced_schema is a _ORTTrainerOptionsInternal. - ort_parameters.sliced_schema = self.options.distributed.pipeline_parallel.sliced_schema._validated_opts - # We have special handling for dictionary-typed option. - # sliced_axes._validated_opts is the original dictionary while sliced_schema is a _ORTTrainerOptionsInternal. - ort_parameters.sliced_axes = self.options.distributed.pipeline_parallel.sliced_axes._validated_opts - ort_parameters.sliced_tensor_names = self.options.distributed.pipeline_parallel.sliced_tensor_names - - ort_parameters.model_after_graph_transforms_path = ( - self.options.debug.graph_save_paths.model_after_graph_transforms_path - ) - ort_parameters.model_with_gradient_graph_path = ( - self.options.debug.graph_save_paths.model_with_gradient_graph_path - ) - ort_parameters.model_with_training_graph_path = ( - self.options.debug.graph_save_paths.model_with_training_graph_path - ) - - # SessionOptions - session_options = ort.SessionOptions() if session_options is None else session_options - session_options.use_deterministic_compute = self.options.debug.deterministic_compute - if ( - self.options.graph_transformer.attn_dropout_recompute - or self.options.graph_transformer.gelu_recompute - or self.options.graph_transformer.transformer_layer_recompute - ): - session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED - if len(self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path) > 0: - session_options.optimized_model_filepath = ( - self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path - ) - - # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. - # for example, load_state_dict will be called before returing the function, and it calls _init_session again - del self._training_session - - # Set provider-specific options if needed - def get_providers(provider_options): - providers = ort.get_available_providers() - if provider_options: - for provider_name in provider_options: - if provider_name in providers: - providers[providers.index(provider_name)] = (provider_name, provider_options[provider_name]) - else: - providers.insert(0, (provider_name, provider_options[provider_name])) - # default: using cuda - elif "cuda" in self.options.device.id.lower(): - gpu_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)} - gpu_ep_name = "ROCMExecutionProvider" if self.is_rocm_pytorch else "CUDAExecutionProvider" - if self.options.device.mem_limit > 0: - gpu_ep_options["gpu_mem_limit"] = self.options.device.mem_limit - - if gpu_ep_name not in providers: - raise RuntimeError( - "ORTTrainer options specify a CUDA device but the {} provider is unavailable.".format( - cuda_ep_name # noqa: F821 - ) - ) - - providers[providers.index(gpu_ep_name)] = (gpu_ep_name, gpu_ep_options) - - return providers - - # TrainingSession - self._training_session = ort.TrainingSession( - self._onnx_model.SerializeToString(), ort_parameters, session_options, get_providers(provider_options) - ) - - # I/O bindings - self._train_io_binding = self._training_session.io_binding() - self._eval_io_binding = self._training_session.io_binding() - - def _init_onnx_model(self, inputs): - if self._onnx_model is not None: - return - - if self._torch_model is not None: - # PyTorch model is moved to cpu to save GPU memory - self._torch_model.cpu() - - # PyTorch buffers (created using 'register_buffer') shouldn't be trained - torch_buffers = list(dict(self._torch_model.named_buffers()).keys()) - self.options.utils.frozen_weights.extend(torch_buffers) - - # Export to ONNX - self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(inputs, "cpu") - - # Post processing for ONNX models expported from PyTorch - if self.options._internal_use.enable_internal_postprocess: - self._onnx_model = postprocess.run_postprocess(self._onnx_model) - if self.options._internal_use.extra_postprocess: - self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model) - - optimizer_state_dict = {} - if self._load_state_dict: - optimizer_state_dict = self._load_state_dict() - - self._init_session( - optimizer_state_dict, - session_options=self.options.session_options, - provider_options=self.options._validated_opts["provider_options"], - ) - - def _init_session(self, optimizer_state_dict={}, session_options=None, provider_options=None): # noqa: B006 - if self._onnx_model is None: - return - - if self.options.utils.run_symbolic_shape_infer: - self._onnx_model = SymbolicShapeInference.infer_shapes( - self._onnx_model, auto_merge=True, guess_output_rank=True - ) - - # Create training session used by train_step - # pass all optimizer states to the backend - self._create_ort_training_session( - optimizer_state_dict, session_options=session_options, provider_options=provider_options - ) - - # Update model description to update dtype when mixed precision is enabled - # C++ backend modifies model's output dtype from float32 to float16 for mixed precision - # Note that for training we must use float32 and for evaluation we must use float16 - for idx, o_desc in enumerate(self.model_desc.outputs): - if ( - self.options.mixed_precision.enabled - and o_desc.dtype == torch.float32 - and not self._training_session.is_output_fp32_node(o_desc.name) - ): - self.model_desc.add_type_to_output_description(idx, o_desc.dtype, torch.float16) - - # Update model description - self._model_desc_inputs_with_lr = [*self.model_desc.inputs, self.model_desc.learning_rate] - - # Update Mixed Precision, if applicable - if self.options.mixed_precision.enabled: - self.model_desc.loss_scale_input = self._training_session.loss_scale_input_name - self._model_desc_inputs_with_lr_and_loss_scale = [ - *self._model_desc_inputs_with_lr, - self.model_desc.loss_scale_input, - ] - self.model_desc.all_finite = _utils.get_all_gradients_finite_name_from_session(self._training_session) - self._model_desc_outputs_with_all_finite = [*self.model_desc.outputs, self.model_desc.all_finite] - elif self.options.mixed_precision.loss_scaler: - raise ValueError("Loss Scaler cannot be specified when Mixed Precision is not enabled") - - # Update Loss Scaler Input Name, if applicable - if self.options.mixed_precision.enabled and self.options.mixed_precision.loss_scaler: - self.options.mixed_precision.loss_scaler.input_name = self.model_desc.loss_scale_input.name - elif not self.options.mixed_precision.enabled and self.options.mixed_precision.loss_scaler: - raise ValueError("Loss Scaler cannot be specified when Mixed Precision is not enabled") - - # Update Gradient Accumulation, if applicable - if self.options.batch.gradient_accumulation_steps > 1: - self.model_desc.gradient_accumulation = _utils.get_gradient_accumulation_name_from_session( - self._training_session - ) - self._model_desc_outputs_with_gradient_accumulation = [ - *self.model_desc.outputs, - self.model_desc.gradient_accumulation, - ] - - # TODO: Remove when experimental checkpoint functions are removed - if self._state_dict: - checkpoint.experimental_load_state_dict(self, self._state_dict, self._load_state_dict_strict) - self._state_dict_debug = self._state_dict - self._state_dict = {} - - def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs): - # Normalize input to tuple of samples - if type(inputs) == tuple and len(inputs) == 1 and type(inputs[0]) == list: # noqa: E721 - input = tuple(inputs[0]) - else: - input = inputs - - # Append input from 'kwargs' - for input_desc in inputs_desc: - if input_desc.name in kwargs: - input = (*input, kwargs[input_desc.name]) - - # Append learning rate - extra_inputs = 0 - if lr is not None: - lr = torch.tensor([lr]) - input += (lr,) - extra_inputs += 1 - - # Append loss scale - if loss_scale is not None: - assert self.options.mixed_precision.enabled, "Loss scale cannot be used without mixed precision" - loss_scale = torch.tensor([loss_scale]) - input += (loss_scale,) - extra_inputs += 1 - - # Only assert length of input when fetches is not used - assert self._train_step_info.fetches or len(self.model_desc.inputs) + extra_inputs == len(input) - return input - - def _resolve_symbolic_dimensions(self, inputs, inputs_desc, outputs_desc): - outputs = copy.deepcopy(outputs_desc) - resolved_dims = {} - for input, i_desc in zip(inputs, inputs_desc): - for i_idx, i_axis in enumerate(i_desc.shape): - if isinstance(i_axis, str): - if i_axis not in resolved_dims: - resolved_dims[i_axis] = input.size()[i_idx] - else: - assert resolved_dims[i_axis] == input.size()[i_idx], f"Mismatch in dynamic shape {i_axis}" - - for o_desc in outputs: - for idx_o, o_axis in enumerate(o_desc.shape): - if isinstance(o_axis, str): - o_desc.shape[idx_o] = resolved_dims[o_axis] - - unknown_dim = [o_desc.name for dim in o_desc.shape for o_desc in outputs if isinstance(dim, str)] - if unknown_dim: - raise RuntimeError(f"Cannot execute model with unknown output dimensions ({unknown_dim}") - - return outputs - - def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_desc, run_options=None): - # Select IO binding - if is_train: - iobinding = self._train_io_binding - else: - iobinding = self._eval_io_binding - - # Get the list of the actual session inputs because unused inputs can be removed. - input_nodes = self._training_session.get_inputs() - input_node_names = [input_node.name for input_node in input_nodes] - - # Bind input tensors - for input, input_desc in zip(inputs, inputs_desc): - if input_desc.name in input_node_names: - device_index = _utils.get_device_index_from_input(input) - iobinding.bind_input( - input_desc.name, - input.device.type, - device_index, - _utils.dtype_torch_to_numpy(input.dtype), - list(input.size()), - input.data_ptr(), - ) - - # Bind output tensors - outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc) - result = {} - for output_desc in outputs_desc_resolved: - target_device = self.options.device.id - if self.options.mixed_precision.enabled and output_desc.name == self.model_desc.all_finite.name: - # Keep all finite flag on CPU to match backend implementation - # This prevents CPU -> GPU -> CPU copies between frontend and backend - target_device = "cpu" - # the self.options.device may be a device that pytorch does not recognize. - # in that case, we temporary prefer to leave the input/output on CPU and let ORT session - # to move the data between device and host. - # so output will be on the same device as input. - try: - torch.device(target_device) - except Exception: - # in this case, input/output must on CPU - assert input.device.type == "cpu" - target_device = "cpu" - - torch_tensor = torch.zeros( - output_desc.shape, - device=target_device, - dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype, - ) - iobinding.bind_output( - output_desc.name, - torch_tensor.device.type, - _utils.get_device_index(target_device), - _utils.dtype_torch_to_numpy(torch_tensor.dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - result[output_desc.name] = torch_tensor - - # Run a train/eval step - self._training_session.run_with_iobinding(iobinding, run_options) - return result - - def _update_onnx_model_initializers(self, state_tensors): - r"""Updates ONNX graph initializers with state_tensors's values - - Usually called to save or load an ONNX model. - - The tensors names of state_tensors are compared to all ONNX initializer tensors - and when the name matches, the ONNX graph is updated with the new value. - """ - assert isinstance(state_tensors, dict), "state_tensors must be a dict" - - new_weights = [] - replace_indices = [] - for i, w in enumerate(self._onnx_model.graph.initializer): - if w.name in state_tensors: - new_weights.append(onnx.numpy_helper.from_array(state_tensors[w.name], w.name)) - replace_indices.append(i) - replace_indices.sort(reverse=True) - for w_i in replace_indices: - del self._onnx_model.graph.initializer[w_i] - self._onnx_model.graph.initializer.extend(new_weights) - - def _extract_model_states(self, state_dict, pytorch_format): - """Extract model states from the training session and load into the state_dict""" - - model_states = self._training_session.get_model_state(include_mixed_precision_weights=False) - state_dict[_utils.state_dict_model_key()] = {} - - # extract trained model weights from the training session - for precision in model_states: - state_dict[_utils.state_dict_model_key()][precision] = {} - for model_state_key in model_states[precision]: - if pytorch_format: - state_dict[_utils.state_dict_model_key()][precision][model_state_key] = torch.from_numpy( - model_states[precision][model_state_key] - ) - else: - state_dict[_utils.state_dict_model_key()][precision][model_state_key] = model_states[precision][ - model_state_key - ] - - # extract untrained (frozen) model weights - for node in self._onnx_model.graph.initializer: - if ( - node.name not in state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] - and node.name in self.options.utils.frozen_weights - ): - if pytorch_format: - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][ - node.name - ] = torch.from_numpy(onnx.numpy_helper.to_array(node)) - else: - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][ - node.name - ] = onnx.numpy_helper.to_array(node) - - def _extract_trainer_options(self, state_dict): - """Extract relevant trainer configuration and load it into the state_dict""" - - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - - state_dict[_utils.state_dict_trainer_options_key()] = {} - state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = self.options.mixed_precision.enabled - state_dict[_utils.state_dict_trainer_options_key()][ - zero_stage - ] = self.options.distributed.deepspeed_zero_optimization.stage - state_dict[_utils.state_dict_trainer_options_key()][world_rank] = self.options.distributed.world_rank - state_dict[_utils.state_dict_trainer_options_key()][world_size] = self.options.distributed.world_size - state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = self.optim_config.name - state_dict[_utils.state_dict_trainer_options_key()][D_size] = self.options.distributed.data_parallel_size - state_dict[_utils.state_dict_trainer_options_key()][H_size] = self.options.distributed.horizontal_parallel_size - - def _extract_train_step_info(self, state_dict): - """Extract train step info settings and save it into the state_dict""" - - optimization_step = _utils.state_dict_train_step_info_optimization_step_key() - step = _utils.state_dict_train_step_info_step_key() - - state_dict[_utils.state_dict_train_step_info_key()] = {} - state_dict[_utils.state_dict_train_step_info_key()][optimization_step] = self._train_step_info.optimization_step - state_dict[_utils.state_dict_train_step_info_key()][step] = self._train_step_info.step - - def state_dict(self, pytorch_format=False): - """Returns a dictionary with model, train step info and optionally, optimizer states - - The returned dictionary contains the following information: - - Model and optimizer states - - Required ORTTrainerOptions settings - - Distributed training information, such as but not limited to ZeRO - - Train step info settings - - Structure of the returned dictionary: - - When `pytorch_format = False` - schema: - { - "model": - { - type: dict, - schema: - { - "full_precision": - { - type: dict, - schema: - { - model_weight_name: - { - type: array - } - } - } - } - }, - "optimizer": - { - type: dict, - schema: - { - model_weight_name: - { - type: dict, - schema: - { - "Moment_1": - { - type: array - }, - "Moment_2": - { - type: array - }, - "Update_Count": - { - type: array, - optional: True # present if optimizer is adam, absent otherwise - } - } - }, - "shared_optimizer_state": - { - type: dict, - optional: True, # present optimizer is shared, absent otherwise. - schema: - { - "step": - { - type: array, - } - } - } - } - }, - "trainer_options": - { - type: dict, - schema: - { - "mixed_precision": - { - type: bool - }, - "zero_stage": - { - type: int - }, - "world_rank": - { - type: int - }, - "world_size": - { - type: int - }, - "optimizer_name": - { - type: str - }, - "data_parallel_size": - { - type: int - }, - "horizontal_parallel_size": - { - type: int - } - } - }, - "partition_info": - { - type: dict, - optional: True, # present if states partitioned, else absent - schema: - { - model_weight_name: - { - type: dict, - schema: - { - "original_dim": - { - type: array - }, - "megatron_row_partition": - { - type: int - } - } - } - } - }, - "train_step_info": - { - type: dict, - schema: - { - "optimization_step": - { - type: int - }, - "step": - { - type: int - } - } - } - } - - When `pytorch_format = True` - schema: - { - model_weight_name: - { - type: tensor - } - } - - Args: - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema - - Returns: - A dictionary with `ORTTrainer` state - """ - if not self._training_session: - warnings.warn( - "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict().", - UserWarning, - ) - return self._load_state_dict.args[0] if self._load_state_dict else {} - - state_dict = {} - - # load training session model states into the state_dict - self._extract_model_states(state_dict, pytorch_format) - if pytorch_format: - if self.options.distributed.deepspeed_zero_optimization.stage > 0: - warnings.warn("Incomplete state_dict: ZeRO enabled", UserWarning) - if self.options.distributed.horizontal_parallel_size > 1: - warnings.warn("Incomplete state_dict: Megatron enabled", UserWarning) - # if pytorch_format is true, return a flat dictionary with only model states - # which is compatible with a PyTorch model - return state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] - - # load training session optimizer states into the state_dict - state_dict[_utils.state_dict_optimizer_key()] = self._training_session.get_optimizer_state() - - # extract the relevant training configuration from the trainer and load them into the state_dict - self._extract_trainer_options(state_dict) - - # Extract train step info settings and load it into the state_dict - self._extract_train_step_info(state_dict) - - # add partition information in case of a distributed run - if ( - self.options.distributed.deepspeed_zero_optimization.stage > 0 - or self.options.distributed.horizontal_parallel_size > 1 - ): - state_dict[_utils.state_dict_partition_info_key()] = self._training_session.get_partition_info_map() - - return state_dict - - def _load_model_states(self, state_dict, strict): - """Load the model states onto the onnx model graph""" - - if _utils.state_dict_model_key() not in state_dict: - return - - # collect all initializer names from the current onnx graph - assert self._onnx_model, "ONNX model graph is not exported" - initializer_names = {node.name for node in self._onnx_model.graph.initializer} - - # loaded_initializers dict will be loaded with all the model states from the state dictionary - # that are found in the initializer_names dictionary - loaded_initializers = {} - - # copy over model states from the input state dict onto the onnx model - for precision, precision_states in state_dict[_utils.state_dict_model_key()].items(): - for state_key, state_value in precision_states.items(): - if state_key in initializer_names: - loaded_initializers[state_key] = state_value - elif strict: - raise RuntimeError(f"Unexpected key: {state_key} in state_dict[model][{precision}]") - - # update onnx model from loaded initializers - self._update_onnx_model_initializers(loaded_initializers) - - def _load_optimizer_states(self, current_state_dict, state_dict): - """Load the optimizer states onto the training session state dictionary""" - - def _check_optimizer_mismatch(state_dict): - """Assert that the loaded optimizer has the same config as the current training session config""" - - # the state_dict optimizer_name can be a byte string (if coming from checkpoint file) - # or can be a regular string (coming from user) - optimizer_name = state_dict[_utils.state_dict_trainer_options_key()][ - _utils.state_dict_trainer_options_optimizer_name_key() - ] - - # optimizer_name can be either a regular string or a byte string. - # if it is a byte string, convert to regular string using decode() - # if it is a regular string, do nothing to it - try: # noqa: SIM105 - optimizer_name = optimizer_name.decode() - except AttributeError: - pass - assert self.optim_config.name == optimizer_name, "Optimizer mismatch: expected {}, got {}".format( - self.optim_config.name, optimizer_name - ) - - if _utils.state_dict_optimizer_key() not in state_dict: - return - - # check optimizer config names are the same for current session and the sessino being loaded - _check_optimizer_mismatch(state_dict) - - # create an entry for the optimizer in the training session state dictionary - if _utils.state_dict_optimizer_key() not in current_state_dict: - current_state_dict[_utils.state_dict_optimizer_key()] = {} - - # copy over optimizer states from the input state dict onto the training session state dict - for model_state_key, optimizer_dict in state_dict[_utils.state_dict_optimizer_key()].items(): - if model_state_key not in current_state_dict[_utils.state_dict_optimizer_key()]: - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key] = {} - for optimizer_state_key, optimizer_state_value in optimizer_dict.items(): - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key][ - optimizer_state_key - ] = optimizer_state_value - - def _load_state_dict_impl(self, state_dict, strict=True): - """Load the state dictionary onto the onnx model and on the training session graph""" - - # clear the callable partial - self._load_state_dict = None - - def _mismatch_keys(keys1, keys2, in_error_str, allow_unexpected=False): - """Find out the missing and the unexpected keys in two dictionaries - - Throws a runtime error if missing or unexpected keys are found - - Keys in keys1 not in keys2 will be marked as missing - - Keys in keys2 not in keys1 will be marked as unexpected - """ - keys1 = set(keys1) - keys2 = set(keys2) - missing_keys = list(keys1 - keys2) - unexpected_keys = list(keys2 - keys1) - if len(missing_keys) > 0: - raise RuntimeError(f"Missing keys: {missing_keys} in {in_error_str}") - if len(unexpected_keys) > 0 and not allow_unexpected: - raise RuntimeError(f"Unexpected keys: {unexpected_keys} in {in_error_str}") - - def _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is any mismatch in the model sub state dictionary between the two state_dicts""" - - # check unxexpected and missing precision keys in the model state_dict compared to the training - # session model state_dict - _mismatch_keys( - current_state_dict[_utils.state_dict_model_key()], - state_dict[_utils.state_dict_model_key()], - "state_dict[model]", - allow_unexpected, - ) - - # check for model state key mismatch - for precision_key in current_state_dict[_utils.state_dict_model_key()]: - _mismatch_keys( - current_state_dict[_utils.state_dict_model_key()][precision_key], - state_dict[_utils.state_dict_model_key()][precision_key], - f"state_dict[model][{precision_key}]", - allow_unexpected, - ) - - def _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is any mismatch in the optimizer sub state dictionary between the two state_dicts""" - - # check for model state key mismatch for the optimizer state_dict - _mismatch_keys( - current_state_dict[_utils.state_dict_optimizer_key()], - state_dict[_utils.state_dict_optimizer_key()], - "state_dict[optimizer]", - allow_unexpected, - ) - - # check for optimizer state keys mismatch - for model_state_key in current_state_dict[_utils.state_dict_optimizer_key()]: - _mismatch_keys( - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key], - state_dict[_utils.state_dict_optimizer_key()][model_state_key], - f"state_dict[optimizer][{model_state_key}]", - allow_unexpected, - ) - - def _check_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is a mismatch in the keys (model and optimizer) in the two state_dicts""" - - # check presence of 'model' in the input state_dict - if _utils.state_dict_model_key() in state_dict: - _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected) - else: - warnings.warn("Missing key: model in state_dict", UserWarning) - # check presence of 'optimizer' in the input state_dict - if _utils.state_dict_optimizer_key() in state_dict: - _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected) - else: - warnings.warn("Missing key: optimizer in state_dict", UserWarning) - - # extract state dict from the current training session. this is to persist the states between - # two training sessions. - # for example, if user provided only the model states, the optimizer states from the current - # training session must be persisted - current_state_dict = {} - if self._training_session: - current_state_dict = self.state_dict() - if strict: - # for Zero enabled, the current trainer might not have the complete state, and we must allow - # extra keys to be present in the state dict - allow_unexpected = self.options.distributed.deepspeed_zero_optimization.stage > 0 - _check_key_mismatch(current_state_dict, state_dict, allow_unexpected) - - # load the model states from the input state dictionary into the onnx graph - self._load_model_states(state_dict, strict) - - # load the optimizer states from the input state dictionary into the training session states - # dictionary - self._load_optimizer_states(current_state_dict, state_dict) - - return ( - current_state_dict[_utils.state_dict_optimizer_key()] - if _utils.state_dict_optimizer_key() in current_state_dict - else {} - ) - - def _load_train_step_info(self, state_dict): - """Load the train step info settings from state dict""" - - if _utils.state_dict_train_step_info_key() not in state_dict: - warnings.warn("Missing key: train_step_info in state_dict", UserWarning) - return - - optimization_step = _utils.state_dict_train_step_info_optimization_step_key() - step = _utils.state_dict_train_step_info_step_key() - - self._train_step_info.optimization_step = state_dict[_utils.state_dict_train_step_info_key()][optimization_step] - self._train_step_info.step = state_dict[_utils.state_dict_train_step_info_key()][step] - - def load_state_dict(self, state_dict, strict=True): - """Loads state_dict containing model/optimizer states into ORTTrainer - - The state_dict dictionary may contain the following information: - - Model and optimizer states - - Required ORTTrainerOptions settings - - Distributed training information, such as but not limited to ZeRO - - Args: - state_dict: state dictionary containing both model and optimizer states. The structure of this dictionary - should be the same as the one that is returned by ORTTrainer.state_dict for the case when pytorch_format=False - strict: boolean flag to strictly enforce that the input state_dict keys match the keys from ORTTrainer.state_dict - """ - - # if onnx graph has not been initialized, loading of states will be put on hold. - # a copy of the state_dict and other arguments to the function will be stored until the onnx graph has - # been initialized. Once the graph is initialized, the desired states will be loaded onto the grpah - if not self._training_session: - self._load_state_dict = partial(self._load_state_dict_impl, state_dict, strict=strict) - return - - # load the train step info settings - self._load_train_step_info(state_dict) - - # load states onto the frontend onnx graph - optimizer_state_dict = self._load_state_dict_impl(state_dict, strict=strict) - - # create a new training session after loading initializer states onto the onnx graph - # pass the populated states to the training session to populate the backend graph - self._init_session( - optimizer_state_dict, - session_options=self.options.session_options, - provider_options=self.options._validated_opts["provider_options"], - ) - - def save_checkpoint(self, path, user_dict={}, include_optimizer_states=True): # noqa: B006 - """Persists ORTTrainer state dictionary on disk along with user_dict. - - Saves the state_dict along with the user_dict to a file specified by path. - - Args: - path: string representation to a file path or a python file-like object. - if file already exists at path, an exception is raised. - user_dict: custom data to be saved along with the state_dict. This data will be returned - to the user when load_checkpoint is called. - include_optimizer_states: boolean flag indicating whether or not to persist the optimizer states. - on load_checkpoint, only model states will be loaded if include_optimizer_states==True - """ - - # extract state_dict to be saved in the checkpoint - state_dict = self.state_dict() - - # if user_dict is provided, serialize to bytes and convert to hex string. - # this helps in loading the types as they are given by the user since hdf5 - # converts to numpy types otherwise - if bool(user_dict): - state_dict[_utils.state_dict_user_dict_key()] = _checkpoint_storage.to_serialized_hex(user_dict) - - # if include_optimizer_states is False, only save the model states in the checkpoint file - if not include_optimizer_states: - if _utils.state_dict_optimizer_key() in state_dict: - del state_dict[_utils.state_dict_optimizer_key()] - - _checkpoint_storage.save(state_dict, path) - - def _aggregation_required(self, loaded_trainer_options): - """Checks if aggregation is required for the loading the state_dict into the ORTTrainer""" - - # To load states in the backend, aggregation is required for every ZeRO - # or Megatron checkpoint - return ( - loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 - or loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1 - ) - - def load_checkpoint(self, *paths, strict=True): - """Loads the saved checkpoint state dictionary into the ORTTrainer - - Reads the saved checkpoint files specified by paths from disk and loads the state dictionary - onto the ORTTrainer. - Aggregates the checkpoint files if aggregation is required. - - Args: - paths: one or more files represented as strings where the checkpoint is saved - strict: boolean flag to strictly enforce that the saved checkpoint state_dict - keys match the keys from ORTTrainer.state_dict - Returns: - dictionary that the user had saved when calling save_checkpoint - """ - state_dict = {} - - # check if aggregation is required - loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key()) - if self._aggregation_required(loaded_trainer_options): - # if aggregation is required, aggregation logic must be run on the saved checkpoints - state_dict = checkpoint.aggregate_checkpoints(paths, pytorch_format=False) - else: - # if aggregation is not required, there must only be a single file that needs to be loaded - assert len(paths) == 1, f"Expected number of files to load: 1, got {len(paths)}" - state_dict = _checkpoint_storage.load(paths[0]) - - # extract user dict from the saved checkpoint - user_dict = {} - if _utils.state_dict_user_dict_key() in state_dict: - user_dict = _checkpoint_storage.from_serialized_hex(state_dict[_utils.state_dict_user_dict_key()]) - del state_dict[_utils.state_dict_user_dict_key()] - - self.load_state_dict(state_dict, strict=strict) - - return user_dict diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py deleted file mode 100644 index c63ac6f82c87f..0000000000000 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ /dev/null @@ -1,692 +0,0 @@ -import cerberus - -import onnxruntime as ort -from onnxruntime.capi._pybind_state import PropagateCastOpsStrategy - -from .amp import loss_scaler -from .optim import lr_scheduler - - -class ORTTrainerOptions: - r"""Settings used by ONNX Runtime training backend - - The parameters are hierarchically organized to facilitate configuration through semantic groups - that encompasses features, such as distributed training, etc. - - Input validation is performed on the input dict during instantiation to ensure - that supported parameters and values are passed in. Invalid input results - in :py:obj:`ValueError` exception with details on it. - - Args: - options (dict): contains all training options - _validate (bool, default is True): for internal use only - - Supported schema for kwargs: - - .. code-block:: python - - schema = { - 'batch' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'gradient_accumulation_steps' : { - 'type' : 'integer', - 'min' : 1, - 'default' : 1 - } - }, - }, - 'device' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'id' : { - 'type' : 'string', - 'default' : 'cuda' - }, - 'mem_limit' : { - 'type' : 'integer', - 'min' : 0, - 'default' : 0 - } - } - }, - 'distributed': { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'world_rank': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'world_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'local_rank': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'data_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'horizontal_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'pipeline_parallel' : { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'pipeline_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'num_pipeline_micro_batches': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'pipeline_cut_info_string': { - 'type': 'string', - 'default': '' - }, - 'sliced_schema': { - 'type': 'dict', - 'default': {}, - 'keysrules': {'type': 'string'}, - 'valuesrules': { - 'type': 'list', - 'schema': {'type': 'integer'} - } - }, - 'sliced_axes': { - 'type': 'dict', - 'default': {}, - 'keysrules': {'type': 'string'}, - 'valuesrules': {'type': 'integer'} - }, - 'sliced_tensor_names': { - 'type': 'list', - 'schema': {'type': 'string'}, - 'default': [] - } - } - }, - 'allreduce_post_accumulation': { - 'type': 'boolean', - 'default': False - }, - 'deepspeed_zero_optimization': { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'stage': { - 'type': 'integer', - 'min': 0, - 'max': 1, - 'default': 0 - }, - } - }, - 'enable_adasum': { - 'type': 'boolean', - 'default': False - } - } - }, - 'lr_scheduler' : { - 'type' : 'optim.lr_scheduler', - 'nullable' : True, - 'default' : None - }, - 'mixed_precision' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'enabled' : { - 'type' : 'boolean', - 'default' : False - }, - 'loss_scaler' : { - 'type' : 'amp.loss_scaler', - 'nullable' : True, - 'default' : None - } - } - }, - 'graph_transformer': { - 'type': 'dict', - 'required': False, - 'default': {}, - 'schema': { - 'attn_dropout_recompute': { - 'type': 'boolean', - 'default': False - }, - 'gelu_recompute': { - 'type': 'boolean', - 'default': False - }, - 'transformer_layer_recompute': { - 'type': 'boolean', - 'default': False - }, - 'number_recompute_layers': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'propagate_cast_ops_config': { - 'type': 'dict', - 'required': False, - 'default': {}, - 'schema': { - 'propagate_cast_ops_strategy': { - 'type': 'onnxruntime.training.PropagateCastOpsStrategy', - 'default': PropagateCastOpsStrategy.FLOOD_FILL - }, - 'propagate_cast_ops_level': { - 'type': 'integer', - 'default': 1 - }, - 'propagate_cast_ops_allow': { - 'type': 'list', - 'schema': {'type': 'string'}, - 'default': [] - } - } - } - } - }, - 'utils' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'frozen_weights' : { - 'type' : 'list', - 'default' : [] - }, - 'grad_norm_clip' : { - 'type' : 'boolean', - 'default' : True - }, - 'memory_efficient_gradient' : { - 'type' : 'boolean', - 'default' : False - }, - 'run_symbolic_shape_infer' : { - 'type' : 'boolean', - 'default' : False - } - } - }, - 'debug' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'deterministic_compute' : { - 'type' : 'boolean', - 'default' : False - }, - 'check_model_export' : { - 'type' : 'boolean', - 'default' : False - }, - 'graph_save_paths' : { - 'type' : 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'model_after_graph_transforms_path': { - 'type': 'string', - 'default': '' - }, - 'model_with_gradient_graph_path':{ - 'type': 'string', - 'default': '' - }, - 'model_with_training_graph_path': { - 'type': 'string', - 'default': '' - }, - 'model_with_training_graph_after_optimization_path': { - 'type': 'string', - 'default': '' - }, - } - }, - } - }, - '_internal_use' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'enable_internal_postprocess' : { - 'type' : 'boolean', - 'default' : True - }, - 'extra_postprocess' : { - 'type' : 'callable', - 'nullable' : True, - 'default' : None - }, - 'onnx_opset_version': { - 'type': 'integer', - 'min' : 12, - 'max' :14, - 'default': 14 - }, - 'enable_onnx_contrib_ops' : { - 'type' : 'boolean', - 'default' : True - } - } - }, - 'provider_options':{ - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': {} - }, - 'session_options': { - 'type': 'SessionOptions', - 'nullable': True, - 'default': None - }, - } - - Keyword arguments: - batch (dict): - batch related settings - batch.gradient_accumulation_steps (int, default is 1): - number of steps to accumulate before do collective gradient reduction - device (dict): - compute device related settings - device.id (string, default is 'cuda'): - device to run training - device.mem_limit (int): - maximum memory size (in bytes) used by device.id - distributed (dict): - distributed training options. - distributed.world_rank (int, default is 0): - rank ID used for data/horizontal parallelism - distributed.world_size (int, default is 1): - number of ranks participating in parallelism - distributed.data_parallel_size (int, default is 1): - number of ranks participating in data parallelism - distributed.horizontal_parallel_size (int, default is 1): - number of ranks participating in horizontal parallelism - distributed.pipeline_parallel (dict): - Options which are only useful to pipeline parallel. - distributed.pipeline_parallel.pipeline_parallel_size (int, default is 1): - number of ranks participating in pipeline parallelism - distributed.pipeline_parallel.num_pipeline_micro_batches (int, default is 1): - number of micro-batches. We divide input batch into micro-batches and run the graph. - distributed.pipeline_parallel.pipeline_cut_info_string (string, default is ''): - string of cutting ids for pipeline partition. - distributed.allreduce_post_accumulation (bool, default is False): - True enables overlap of AllReduce with computation, while False, - postpone AllReduce until all gradients are ready - distributed.deepspeed_zero_optimization: - DeepSpeed ZeRO options. - distributed.deepspeed_zero_optimization.stage (int, default is 0): - select which stage of DeepSpeed ZeRO to use. Stage 0 means disabled. - distributed.enable_adasum (bool, default is False): - enable `Adasum `_ - algorithm for AllReduce - lr_scheduler (optim._LRScheduler, default is None): - specifies learning rate scheduler - mixed_precision (dict): - mixed precision training options - mixed_precision.enabled (bool, default is False): - enable mixed precision (fp16) - mixed_precision.loss_scaler (amp.LossScaler, default is None): - specifies a loss scaler to be used for fp16. If not specified, - :py:class:`.DynamicLossScaler` is used with default values. - Users can also instantiate :py:class:`.DynamicLossScaler` and - override its parameters. Lastly, a completely new implementation - can be specified by extending :py:class:`.LossScaler` class from scratch - graph_transformer (dict): - graph transformer related configurations - graph_transformer.attn_dropout_recompute(bool, default False) - graph_transformer.gelu_recompute(bool, default False) - graph_transformer.transformer_layer_recompute(bool, default False) - graph_transformer.number_recompute_layers(bool, default False) - graph_transformer.propagate_cast_ops_config (dict): - graph_transformer.propagate_cast_ops_config.strategy(PropagateCastOpsStrategy, default FLOOD_FILL) - Specify the choice of the cast propagation optimization strategy, either, NONE, INSERT_AND_REDUCE or FLOOD_FILL. - NONE strategy does not perform any cast propagation transformation on the graph, although other optimizations - locally change cast operations, for example, in order to fuse Transpose and MatMul nodes, the TransposeMatMulFunsion optimization could - interchange Transpose and Cast if the Cast node exists between Transpose and MatMul. - INSERT_AND_REDUCE strategy inserts and reduces cast operations around the nodes with allowed opcodes. - FLOOD_FILL strategy expands float16 regions in the graph using the allowed opcodes, and unlike - INSERT_AND_REDUCE does not touch opcodes outside expanded float16 region. - graph_transformer.propagate_cast_ops_config.level(integer, default 1) - Optimize by moving Cast operations if propagate_cast_ops_level is non-negative. - Use predetermined list of opcodes considered safe to move before/after cast operation - if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise. - graph_transformer.propagate_cast_ops_config.allow(list of str, []) - List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero. - attn_dropout_recompute (bool, default is False): - enable recomputing attention dropout to save memory - gelu_recompute (bool, default is False): - enable recomputing Gelu activation output to save memory - transformer_layer_recompute (bool, default is False): - enable recomputing transformer layerwise to save memory - number_recompute_layers (int, default is 0) - number of layers to apply transformer_layer_recompute, by default system will - apply recompute to all the layers, except for the last one - utils (dict): - miscellaneous options - utils.frozen_weights (list of str, []): - list of model parameter names to skip training (weights don't change) - utils.grad_norm_clip (bool, default is True): - enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer' - utils.memory_efficient_gradient (bool, default is False): - enables use of memory aware gradient builder. - utils.run_symbolic_shape_infer (bool, default is False): - runs symbolic shape inference on the model - debug (dict): - debug options - debug.deterministic_compute (bool, default is False) - forces compute to be deterministic accross runs - debug.check_model_export (bool, default is False) - compares PyTorch model outputs with ONNX model outputs in inference before the first - train step to ensure successful model export - debug.graph_save_paths (dict): - paths used for dumping ONNX graphs for debugging purposes - debug.graph_save_paths.model_after_graph_transforms_path (str, default is "") - path to export the ONNX graph after training-related graph transforms have been applied. - No output when it is empty. - debug.graph_save_paths.model_with_gradient_graph_path (str, default is "") - path to export the ONNX graph with the gradient graph added. No output when it is empty. - debug.graph_save_paths.model_with_training_graph_path (str, default is "") - path to export the training ONNX graph with forward, gradient and optimizer nodes. - No output when it is empty. - debug.graph_save_paths.model_with_training_graph_after_optimization_path (str, default is "") - outputs the optimized training graph to the path if nonempty. - _internal_use (dict): - internal options, possibly undocumented, that might be removed without notice - _internal_use.enable_internal_postprocess (bool, default is True): - enable internal internal post processing of the ONNX model - _internal_use.extra_postprocess (callable, default is None) - a functor to postprocess the ONNX model and return a new ONNX model. - It does not override :py:attr:`._internal_use.enable_internal_postprocess`, but complement it - _internal_use.onnx_opset_version (int, default is 14): - ONNX opset version used during model exporting. - _internal_use.enable_onnx_contrib_ops (bool, default is True) - enable PyTorch to export nodes as contrib ops in ONNX. - This flag may be removed anytime in the future. - session_options (onnxruntime.SessionOptions): - The SessionOptions instance that TrainingSession will use. - provider_options (dict): - The provider_options for customized execution providers. it is dict map from EP name to - a key-value pairs, like {'EP1' : {'key1' : 'val1'}, ....} - - Example: - .. code-block:: python - - opts = ORTTrainerOptions({ - 'batch' : { - 'gradient_accumulation_steps' : 128 - }, - 'device' : { - 'id' : 'cuda:0', - 'mem_limit' : 2*1024*1024*1024, - }, - 'lr_scheduler' : optim.lr_scheduler.LinearWarmupLRScheduler(), - 'mixed_precision' : { - 'enabled': True, - 'loss_scaler': amp.LossScaler(loss_scale=float(1 << 16)) - } - }) - fp16_enabled = opts.mixed_precision.enabled - """ - - def __init__(self, options={}): # noqa: B006 - # Keep a copy of original input for debug - self._original_opts = dict(options) - - # Used for logging purposes - self._main_class_name = self.__class__.__name__ - - # Validates user input - self._validated_opts = dict(self._original_opts) - validator = ORTTrainerOptionsValidator(_ORTTRAINER_OPTIONS_SCHEMA) - self._validated_opts = validator.validated(self._validated_opts) - if self._validated_opts is None: - raise ValueError(f"Invalid options: {validator.errors}") - - # Convert dict in object - for k, v in self._validated_opts.items(): - setattr(self, k, self._wrap(v)) - - def __repr__(self): - return "{%s}" % str( - ", ".join( - f"'{k}': {v!r}" - for (k, v) in self.__dict__.items() - if k not in ["_original_opts", "_validated_opts", "_main_class_name"] - ) - ) - - def _wrap(self, v): - if isinstance(v, (tuple, list, set, frozenset)): - return type(v)([self._wrap(i) for i in v]) - else: - return _ORTTrainerOptionsInternal(self._main_class_name, v) if isinstance(v, dict) else v - - -class _ORTTrainerOptionsInternal(ORTTrainerOptions): - r"""Internal class used by ONNX Runtime training backend for input validation - - NOTE: Users MUST NOT use this class in any way! - """ - - def __init__(self, main_class_name, options): - # Used for logging purposes - self._main_class_name = main_class_name - # We don't call super().__init__(options) here but still called it "_validated_opts" - # instead of "_original_opts" because it has been validated in the top-level - # ORTTrainerOptions's constructor. - self._validated_opts = dict(options) - # Convert dict in object - for k, v in dict(options).items(): - setattr(self, k, self._wrap(v)) - - -class ORTTrainerOptionsValidator(cerberus.Validator): - _LR_SCHEDULER = cerberus.TypeDefinition("lr_scheduler", (lr_scheduler._LRScheduler,), ()) - _LOSS_SCALER = cerberus.TypeDefinition("loss_scaler", (loss_scaler.LossScaler,), ()) - - _SESSION_OPTIONS = cerberus.TypeDefinition("session_options", (ort.SessionOptions,), ()) - - _PROPAGATE_CAST_OPS_STRATEGY = cerberus.TypeDefinition( - "propagate_cast_ops_strategy", (PropagateCastOpsStrategy,), () - ) - - types_mapping = cerberus.Validator.types_mapping.copy() - types_mapping["lr_scheduler"] = _LR_SCHEDULER - types_mapping["loss_scaler"] = _LOSS_SCALER - types_mapping["session_options"] = _SESSION_OPTIONS - types_mapping["propagate_cast_ops_strategy"] = _PROPAGATE_CAST_OPS_STRATEGY - - -def _check_is_callable(field, value, error): - result = False - try: - # Python 3 - result = value is None or callable(value) - except Exception: - # Python 3 but < 3.2 - if hasattr(value, "__call__"): # noqa: B004 - result = True - if not result: - error(field, "Must be callable or None") - - -_ORTTRAINER_OPTIONS_SCHEMA = { - "batch": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": {"gradient_accumulation_steps": {"type": "integer", "min": 1, "default": 1}}, - }, - "device": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "id": {"type": "string", "default": "cuda"}, - "mem_limit": {"type": "integer", "min": 0, "default": 0}, - }, - }, - "distributed": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "world_rank": {"type": "integer", "min": 0, "default": 0}, - "world_size": {"type": "integer", "min": 1, "default": 1}, - "local_rank": {"type": "integer", "min": 0, "default": 0}, - "data_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "horizontal_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "pipeline_parallel": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "pipeline_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "num_pipeline_micro_batches": {"type": "integer", "min": 1, "default": 1}, - "pipeline_cut_info_string": {"type": "string", "default": ""}, - "sliced_schema": { - "type": "dict", - "default_setter": lambda _: {}, - "keysrules": {"type": "string"}, - "valuesrules": {"type": "list", "schema": {"type": "integer"}}, - }, - "sliced_axes": { - "type": "dict", - "default_setter": lambda _: {}, - "keysrules": {"type": "string"}, - "valuesrules": {"type": "integer"}, - }, - "sliced_tensor_names": {"type": "list", "schema": {"type": "string"}, "default": []}, - }, - }, - "allreduce_post_accumulation": {"type": "boolean", "default": False}, - "deepspeed_zero_optimization": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "stage": {"type": "integer", "min": 0, "max": 1, "default": 0}, - }, - }, - "enable_adasum": {"type": "boolean", "default": False}, - }, - }, - "lr_scheduler": {"type": "lr_scheduler", "nullable": True, "default": None}, - "mixed_precision": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "enabled": {"type": "boolean", "default": False}, - "loss_scaler": {"type": "loss_scaler", "nullable": True, "default": None}, - }, - }, - "graph_transformer": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "attn_dropout_recompute": {"type": "boolean", "default": False}, - "gelu_recompute": {"type": "boolean", "default": False}, - "transformer_layer_recompute": {"type": "boolean", "default": False}, - "number_recompute_layers": {"type": "integer", "min": 0, "default": 0}, - "propagate_cast_ops_config": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "strategy": { - "type": "propagate_cast_ops_strategy", - "nullable": True, - "default": PropagateCastOpsStrategy.FLOOD_FILL, - }, - "level": {"type": "integer", "min": -1, "default": 1}, - "allow": {"type": "list", "schema": {"type": "string"}, "default": []}, - }, - }, - }, - }, - "utils": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "frozen_weights": {"type": "list", "default": []}, - "grad_norm_clip": {"type": "boolean", "default": True}, - "memory_efficient_gradient": {"type": "boolean", "default": False}, - "run_symbolic_shape_infer": {"type": "boolean", "default": False}, - }, - }, - "debug": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "deterministic_compute": {"type": "boolean", "default": False}, - "check_model_export": {"type": "boolean", "default": False}, - "graph_save_paths": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "model_after_graph_transforms_path": {"type": "string", "default": ""}, - "model_with_gradient_graph_path": {"type": "string", "default": ""}, - "model_with_training_graph_path": {"type": "string", "default": ""}, - "model_with_training_graph_after_optimization_path": {"type": "string", "default": ""}, - }, - }, - }, - }, - "_internal_use": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "enable_internal_postprocess": {"type": "boolean", "default": True}, - "extra_postprocess": {"check_with": _check_is_callable, "nullable": True, "default": None}, - "onnx_opset_version": {"type": "integer", "min": 12, "max": 14, "default": 14}, - "enable_onnx_contrib_ops": {"type": "boolean", "default": True}, - }, - }, - "provider_options": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "allow_unknown": True, - "schema": {}, - }, - "session_options": {"type": "session_options", "nullable": True, "default": None}, -} diff --git a/orttraining/orttraining/python/training/postprocess.py b/orttraining/orttraining/python/training/postprocess.py deleted file mode 100644 index 6c2adb6af7978..0000000000000 --- a/orttraining/orttraining/python/training/postprocess.py +++ /dev/null @@ -1,478 +0,0 @@ -import os.path # noqa: F401 -import struct -import sys # noqa: F401 - -import numpy as np # noqa: F401 -import onnx -from onnx import * # noqa: F403 -from onnx import helper, numpy_helper # noqa: F401 - - -def run_postprocess(model): - # this post pass is not required for pytorch >= 1.5 - # where add_node_name in torch.onnx.export is default to True - model = add_name(model) - - # this post pass is not required for pytorch > 1.6 - model = fuse_softmaxNLL_to_softmaxCE(model) - - model = fix_expand_shape(model) - model = fix_expand_shape_pt_1_5(model) - return model - - -def find_input_node(model, arg): - result = [] - for node in model.graph.node: - for output in node.output: - if output == arg: - result.append(node) - return result[0] if len(result) == 1 else None - - -def find_output_node(model, arg): - result = [] - for node in model.graph.node: - for input in node.input: - if input == arg: - result.append(node) - return result[0] if len(result) == 1 else result - - -def add_name(model): - i = 0 - for node in model.graph.node: - node.name = "%s_%d" % (node.op_type, i) - i += 1 - return model - - -# Expand Shape PostProcess - - -def fix_expand_shape(model): - expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"] - model_inputs_names = [i.name for i in model.graph.input] - - for expand_node in expand_nodes: - shape = find_input_node(model, expand_node.input[1]) - if shape.op_type == "Shape": - # an expand subgraph - # Input Input2 - # | | - # | Shape - # | | - # |__ __| - # | | - # Expand - # | - # output - # - # Only if Input2 is one of the model inputs, assign Input2's shape to output of expand. - shape_input_name = shape.input[0] - if shape_input_name in model_inputs_names: - index = model_inputs_names.index(shape_input_name) - expand_out = model.graph.value_info.add() - expand_out.name = expand_node.output[0] - expand_out.type.CopyFrom(model.graph.input[index].type) - return model - - -def fix_expand_shape_pt_1_5(model): - # expand subgraph - # Constant - # + - # ConstantOfShape - # | + | - # | + | - # (Reshape subgraph) Mul | - # |___ _________| | - # + | | | - # + Equal | - # +++++|++++++++++++++|++ - # |____________ | + - # | | + - # (subgraph) Where - # | | - # |_____ ___________| - # | | - # Expand - # | - # output - # - # where the Reshape subgraph is - # - # Input - # | | - # | |___________________ - # | | - # Shape Constant Shape Constant - # | ______| | ______| - # | | | | - # Gather Gather - # | | - # Unsqueeze Unsqueeze - # | | - # | ..Number of dims.. | - # | _________________| - # |...| - # Concat Constant - # | | - # |______ __________________| - # | | - # Reshape - # | - # output - # - # This pass will copy Input's shape to the output of Expand. - expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"] - model_inputs_names = [i.name for i in model.graph.input] - - for expand_node in expand_nodes: - n_where = find_input_node(model, expand_node.input[1]) - if n_where.op_type != "Where": - continue - - n_equal = find_input_node(model, n_where.input[0]) - n_cos = find_input_node(model, n_where.input[1]) - n_reshape = find_input_node(model, n_where.input[2]) - - if n_equal.op_type != "Equal" or n_cos.op_type != "ConstantOfShape" or n_reshape.op_type != "Reshape": - continue - - n_reshape_e = find_input_node(model, n_equal.input[0]) - n_mul = find_input_node(model, n_equal.input[1]) - if n_reshape_e != n_reshape or n_mul.op_type != "Mul": - continue - - n_cos_m = find_input_node(model, n_mul.input[0]) - n_constant = find_input_node(model, n_mul.input[1]) - if n_cos_m != n_cos or n_constant.op_type != "Constant": - continue - - n_concat = find_input_node(model, n_reshape.input[0]) - n_constant_r = find_input_node(model, n_reshape.input[1]) - if n_concat.op_type != "Concat" or n_constant_r.op_type != "Constant": - continue - - n_input_candidates = [] - for concat_in in n_concat.input: - n_unsqueeze = find_input_node(model, concat_in) - if n_unsqueeze.op_type != "Unsqueeze": - break - n_gather = find_input_node(model, n_unsqueeze.input[0]) - if n_gather.op_type != "Gather": - break - n_shape = find_input_node(model, n_gather.input[0]) - n_constant_g = find_input_node(model, n_gather.input[1]) - if n_shape.op_type != "Shape" or n_constant_g.op_type != "Constant": - break - n_input = n_shape.input[0] - if n_input not in model_inputs_names: - break - n_input_candidates.append(n_input) - - if not n_input_candidates or not all(elem == n_input_candidates[0] for elem in n_input_candidates): - continue - - index = model_inputs_names.index(n_input_candidates[0]) - expand_out = model.graph.value_info.add() - expand_out.name = expand_node.output[0] - expand_out.type.CopyFrom(model.graph.input[index].type) - return model - - -# LayerNorm PostProcess - - -def find_nodes(graph, op_type): - nodes = [] - for node in graph.node: - if node.op_type == op_type: - nodes.append(node) - return nodes - - -def is_type(node, op_type): - if node is None or isinstance(node, list): - return False - return node.op_type == op_type - - -def add_const(model, name, output, t_value=None, f_value=None): - const_node = model.graph.node.add() - const_node.op_type = "Constant" - const_node.name = name - const_node.output.extend([output]) - attr = const_node.attribute.add() - attr.name = "value" - if t_value is not None: - attr.type = 4 - attr.t.CopyFrom(t_value) - else: - attr.type = 1 - attr.f = f_value - return const_node - - -def layer_norm_transform(model): - # DEPRECATED: This pass is no longer needed as the transform is handled at the backend. - # Converting below subgraph - # - # input - # | - # ReduceMean - # | - # Sub Constant - # _||_____ | - # | | | - # | | | - # | (optional) Cast (optional) Cast - # | | | - # | | ____________________| - # | | | - # | Pow - # | | - # | ReduceMean - # | | - # | Add - # | | - # |__ __Sqrt - # | | - # Div (weight) - # | | - # | _____| - # | | - # Mul (bias) - # | | - # | _____| - # | | - # Add - # | - # output - # - # to the below subgraph - # - # input (weight) (bias) - # | | | - # | _______| | - # | | ________________| - # | | | - # LayerNormalization - # | - # output - graph = model.graph - - nodes_ReduceMean = find_nodes(graph, "ReduceMean") # noqa: N806 - - id = 0 - layer_norm_nodes = [] - remove_nodes = [] - for reduce_mean in nodes_ReduceMean: - # check that reduce_mean output is Sub - sub = find_output_node(model, reduce_mean.output[0]) - if not is_type(sub, "Sub"): - continue - - # check that sub output[0] is Div and output[1] is Pow - pow, div = find_output_node(model, sub.output[0]) - if is_type(pow, "Cast"): - # During an update in PyTorch, Cast nodes are inserted between Sub and Pow. - remove_nodes += [pow] - pow = find_output_node(model, pow.output[0]) - if not is_type(pow, "Pow"): - continue - cast_pow = find_input_node(model, pow.input[1]) - if not is_type(cast_pow, "Cast"): - continue - remove_nodes += [cast_pow] - if not is_type(div, "Div") or not is_type(pow, "Pow"): - continue - - # check that pow ouput is ReduceMean - reduce_mean2 = find_output_node(model, pow.output[0]) - if not is_type(reduce_mean2, "ReduceMean"): - continue - - # check that reduce_mean2 output is Add - add = find_output_node(model, reduce_mean2.output[0]) - if not is_type(add, "Add"): - continue - - # check that add output is Sqrt - sqrt = find_output_node(model, add.output[0]) - if not is_type(sqrt, "Sqrt"): - continue - - # check that sqrt output is div - if div != find_output_node(model, sqrt.output[0]): - continue - - # check if div output is Mul - optional_mul = find_output_node(model, div.output[0]) - if not is_type(optional_mul, "Mul"): - optional_mul = None - continue # default bias and weight not supported - - # check if mul output is Add - if optional_mul is not None: - optional_add = find_output_node(model, optional_mul.output[0]) - else: - optional_add = find_output_node(model, div.output[0]) - if not is_type(optional_add, "Add"): - optional_add = None - continue # default bias and weight not supported - - # add nodes to remove_nodes - remove_nodes.extend([reduce_mean, sub, div, pow, reduce_mean2, add, sqrt]) - - # create LayerNorm node - layer_norm_input = [] - layer_norm_output = [] - - layer_norm_input.append(reduce_mean.input[0]) - - if optional_mul is not None: - remove_nodes.append(optional_mul) - weight = optional_mul.input[1] - layer_norm_input.append(weight) - - if optional_add is not None: - remove_nodes.append(optional_add) - bias = optional_add.input[1] - layer_norm_input.append(bias) - - if optional_add is not None: - layer_norm_output.append(optional_add.output[0]) - elif optional_mul is not None: - layer_norm_output.append(optional_mul.output[0]) - else: - layer_norm_output.append(div.output[0]) - - layer_norm_output.append("saved_mean_" + str(id)) - layer_norm_output.append("saved_inv_std_var_" + str(id)) - - epsilon_node = find_input_node(model, add.input[1]) - epsilon = epsilon_node.attribute[0].t.raw_data - epsilon = struct.unpack("f", epsilon)[0] - - layer_norm = helper.make_node( - "LayerNormalization", - layer_norm_input, - layer_norm_output, - "LayerNormalization_" + str(id), - None, - axis=reduce_mean.attribute[0].ints[0], - epsilon=epsilon, - ) - layer_norm_nodes.append(layer_norm) - id += 1 - - # remove orphan constant nodes - for constant in graph.node: - if constant.op_type == "Constant" and constant not in remove_nodes: - is_orphan = True - for out_name in constant.output: - out = find_output_node(model, out_name) - if out not in remove_nodes: - is_orphan = False - if is_orphan: - remove_nodes.append(constant) - - all_nodes = [] - for node in graph.node: - if node not in remove_nodes: - all_nodes.append(node) - - for node in layer_norm_nodes: - all_nodes.append(node) # noqa: PERF402 - - graph.ClearField("node") - graph.node.extend(all_nodes) - return model - - -# Fuse SoftmaxCrossEntropy - - -def fuse_softmaxNLL_to_softmaxCE(onnx_model): # noqa: N802 - # Converting below subgraph - # - # (subgraph) - # | - # LogSoftmax (target) (optional weight) - # | | | - # nll_loss/NegativeLogLikelihoodLoss - # | - # output - # - # to the following - # - # (subgraph) (target) (optional weight) - # | | _____| - # | | | - # SparseSoftmaxCrossEntropy - # | - # output - nll_count = 0 - while True: - nll_count = nll_count + 1 - nll_loss_node = None - nll_loss_node_index = 0 - for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss": - nll_loss_node = node - break - - if nll_loss_node is None: - break - - softmax_node = None - softmax_node_index = 0 - label_input_name = None - weight_input_name = None - for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "LogSoftmax": - # has to be connected to nll_loss - if len(nll_loss_node.input) > 2: - weight_input_name = nll_loss_node.input[2] - if node.output[0] == nll_loss_node.input[0]: - softmax_node = node - label_input_name = nll_loss_node.input[1] - break - elif node.output[0] == nll_loss_node.input[1]: - softmax_node = node - label_input_name = nll_loss_node.input[0] - break - else: - if softmax_node is not None: - break - - if softmax_node is None: - break - - # delete nll_loss and LogSoftmax nodes in order - if nll_loss_node_index < softmax_node_index: - del onnx_model.graph.node[softmax_node_index] - del onnx_model.graph.node[nll_loss_node_index] - else: - del onnx_model.graph.node[nll_loss_node_index] - del onnx_model.graph.node[softmax_node_index] - - probability_output_name = softmax_node.output[0] - node = onnx_model.graph.node.add() - inputs = ( - [softmax_node.input[0], label_input_name, weight_input_name] - if weight_input_name - else [softmax_node.input[0], label_input_name] - ) - node.CopyFrom( - onnx.helper.make_node( - "SparseSoftmaxCrossEntropy", - inputs, - [nll_loss_node.output[0], probability_output_name], - "nll_loss_node_" + str(nll_count), - ) - ) - - return onnx_model diff --git a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py b/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py deleted file mode 100644 index f57f55d14eb1b..0000000000000 --- a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py +++ /dev/null @@ -1,144 +0,0 @@ -import sys -import threading -import time - - -class OutputGrabber: - """ - Class used to grab standard output or another stream. - """ - - escape_char = "\b" - - def __init__(self, stream=None, threaded=False): - self.origstream = stream - self.threaded = threaded - if self.origstream is None: - self.origstream = sys.stdout - self.origstreamfd = self.origstream.fileno() - self.capturedtext = "" - # Create a pipe so the stream can be captured: - self.pipe_out, self.pipe_in = os.pipe() - - def __enter__(self): - self.start() - return self - - def __exit__(self, type, value, traceback): - self.stop() - - def start(self): - """ - Start capturing the stream data. - """ - self.capturedtext = "" - # Save a copy of the stream: - self.streamfd = os.dup(self.origstreamfd) - # Replace the original stream with our write pipe: - os.dup2(self.pipe_in, self.origstreamfd) - if self.threaded: - # Start thread that will read the stream: - self.workerThread = threading.Thread(target=self.readOutput) - self.workerThread.start() - # Make sure that the thread is running and os.read() has executed: - time.sleep(0.01) - - def stop(self): - """ - Stop capturing the stream data and save the text in `capturedtext`. - """ - # Print the escape character to make the readOutput method stop: - self.origstream.write(self.escape_char) - # Flush the stream to make sure all our data goes in before - # the escape character: - self.origstream.flush() - if self.threaded: - # wait until the thread finishes so we are sure that - # we have until the last character: - self.workerThread.join() - else: - self.readOutput() - # Close the pipe: - os.close(self.pipe_in) - os.close(self.pipe_out) - # Restore the original stream: - os.dup2(self.streamfd, self.origstreamfd) - # Close the duplicate stream: - os.close(self.streamfd) - - def readOutput(self): - """ - Read the stream data (one byte at a time) - and save the text in `capturedtext`. - """ - while True: - char = os.read(self.pipe_out, 1).decode(self.origstream.encoding) - if not char or self.escape_char in char: - break - self.capturedtext += char - - -import os # noqa: E402 -import unittest # noqa: E402 - -import numpy as np # noqa: E402, F401 -import torch # noqa: E402 -import torch.nn as nn # noqa: E402 -import torch.nn.functional as F # noqa: E402 - -from onnxruntime.capi import _pybind_state as torch_ort_eager # noqa: E402, F401 -from onnxruntime.training import optim, orttrainer, orttrainer_options # noqa: E402, F401 - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x, target): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return my_loss(out, target) - - -class OrtEPTests(unittest.TestCase): - def test_external_graph_transformer_triggering(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - batch_size = 128 - model = NeuralNet(input_size, hidden_size, num_classes) - - model_desc = { - "inputs": [ - ("x", [batch_size, input_size]), - ( - "target", - [ - batch_size, - ], - ), - ], - "outputs": [("loss", [], True)], - } - optim_config = optim.SGDConfig() - opts = orttrainer.ORTTrainerOptions({"device": {"id": "cpu"}}) - model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - # because orttrainer is lazy initialized, feed in a random data to trigger the graph transformer - data = torch.rand(batch_size, input_size) - target = torch.randint(0, 10, (batch_size,)) - - with OutputGrabber() as out: - model.train_step(data, target) - assert "******************Trigger Customized Graph Transformer: MyGraphTransformer!" in out.capturedtext - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc b/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc deleted file mode 100644 index 00e933dd14914..0000000000000 --- a/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc +++ /dev/null @@ -1,35 +0,0 @@ -#include "core/optimizer/rewrite_rule.h" -#include "orttraining/core/optimizer/graph_transformer_registry.h" -#include "onnx/defs/schema.h" -#include -#include - -namespace onnxruntime { -namespace training { - -class MyRewriteRule : public RewriteRule { - public: - MyRewriteRule() noexcept - : RewriteRule("MyRewriteRule") { - } - std::vector TargetOpTypes() const noexcept override { - return {}; - } - - private: - bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/, const logging::Logger& /*logger*/) const override { - return true; - } - - Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/, const logging::Logger& /*logger*/) const override { - std::cout << "******************Trigger Customized Graph Transformer: MyGraphTransformer!" << std::endl; - return Status::OK(); - } -}; - -void RegisterTrainingExternalTransformers() { - ONNX_REGISTER_EXTERNAL_REWRITE_RULE(MyRewriteRule, Level1, true); -} - -} // namespace training -} // namespace onnxruntime diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 20b9354d85745..b774fec11cc8d 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -35,6 +35,7 @@ #ifdef ENABLE_TRITON #include "orttraining/core/optimizer/triton_fusion.h" #endif +#include "orttraining/core/optimizer/conv1d_replacement.h" #include @@ -1199,6 +1200,103 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) { ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1); } +TEST_F(GraphTransformationTests, Conv1dReplacement) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + for (auto group : {1, 2}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / group, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(group)); + }; + + auto post_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 0); + // after graph transformation, the graph should have 1 squeeze, 2 split, group matmul, 1 concat + TEST_RETURN_IF_NOT(op_count_map["Squeeze"] == 1); + TEST_RETURN_IF_NOT(op_count_map["Split"] == 2); + TEST_RETURN_IF_NOT(op_count_map["MatMul"] == group); + TEST_RETURN_IF_NOT(op_count_map["Concat"] == 1); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + } +} + +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + // "group" is 3 so conv not replaced + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(3)); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } + + // "kernel_shape" is not 1 so conv not replaced + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{2}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(1)); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } +} + INSTANTIATE_TEST_SUITE_P( QDQFusionTests, QDQFusionTestsParameterized, diff --git a/orttraining/orttraining/test/python/_test_commons.py b/orttraining/orttraining/test/python/_test_commons.py index 1413d59096832..fb7e62551de63 100644 --- a/orttraining/orttraining/test/python/_test_commons.py +++ b/orttraining/orttraining/test/python/_test_commons.py @@ -1,26 +1,7 @@ -import copy -import math import os import subprocess import sys -import numpy as np -import onnx -import torch -from numpy.testing import assert_allclose - -import onnxruntime -from onnxruntime.training import _utils, optim - - -def _single_run(execution_file, scenario, checkopint_dir=None): - cmd = [sys.executable, execution_file] - if scenario: - cmd += ["--scenario", scenario] - if checkopint_dir: - cmd += ["--checkpoint_dir", checkopint_dir] - assert subprocess.call(cmd) == 0 - def is_windows(): return sys.platform.startswith("win") @@ -46,197 +27,3 @@ def run_subprocess(args, cwd=None, capture=False, dll_path=None, shell=False, en if log: log.debug("Subprocess completed. Return code=" + str(completed_process.returncode)) return completed_process - - -def legacy_constant_lr_scheduler(global_step, initial_lr, total_steps, warmup): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - new_lr = initial_lr - return new_lr - - -def legacy_cosine_lr_scheduler(global_step, initial_lr, total_steps, warmup, cycles): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - progress = float(global_step - num_warmup_steps) / float(max(1, total_steps - num_warmup_steps)) - new_lr = initial_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * progress))) - return new_lr - - -def legacy_linear_lr_scheduler(global_step, initial_lr, total_steps, warmup): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - new_lr = initial_lr * max(0.0, float(total_steps - global_step) / float(max(1, total_steps - num_warmup_steps))) - return new_lr - - -def legacy_poly_lr_scheduler(global_step, initial_lr, total_steps, warmup, power, lr_end): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - elif global_step > total_steps: - new_lr = lr_end - else: - lr_range = initial_lr - lr_end - decay_steps = total_steps - num_warmup_steps - pct_remaining = 1 - (global_step - num_warmup_steps) / decay_steps - decay = lr_range * pct_remaining**power + lr_end - new_lr = decay - return new_lr - - -def generate_dummy_optim_state(model, optimizer): - np.random.seed(0) - if not (isinstance(optimizer, (optim.AdamConfig, optim.LambConfig))): - return dict() - - moment_keys = ["Moment_1", "Moment_2"] - uc_key = "Update_Count" - step_key = "Step" - shared_state_key = "shared_optimizer_state" - - optim_state = dict() - weight_shape_map = dict() - if isinstance(model, torch.nn.Module): - weight_shape_map = {name: param.size() for name, param in model.named_parameters()} - elif isinstance(model, onnx.ModelProto): - weight_shape_map = {n.name: n.dims for n in model.graph.initializer} - else: - raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'") - - for weight_name, weight_shape in weight_shape_map.items(): - per_weight_state = dict() - for moment in moment_keys: - per_weight_state[moment] = np.random.uniform(-2, 2, weight_shape).astype(np.float32) - if isinstance(optimizer, optim.AdamConfig): - per_weight_state[uc_key] = np.full([1], 5, dtype=np.int64) - optim_state[weight_name] = copy.deepcopy(per_weight_state) - if isinstance(optimizer, optim.LambConfig): - step_val = np.full([1], 5, dtype=np.int64) - optim_state[shared_state_key] = {step_key: step_val} - return {"optimizer": optim_state, "trainer_options": {"optimizer_name": optimizer.name}} - - -def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False, data_dir=None): - # Loads external Pytorch TransformerModel into utils - root = "samples" - if not os.path.exists(root): - root = os.path.normpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "..", "samples") - ) - if not os.path.exists(root): - raise FileNotFoundError("Unable to find folder 'samples', tried %r." % root) - pytorch_transformer_path = os.path.join(root, "python", "training", "orttrainer", "pytorch_transformer") - pt_model_path = os.path.join(pytorch_transformer_path, "pt_model.py") - pt_model = _utils.import_module_from_file(pt_model_path) - ort_utils_path = os.path.join(pytorch_transformer_path, "ort_utils.py") - ort_utils = _utils.import_module_from_file(ort_utils_path) - utils_path = os.path.join(pytorch_transformer_path, "utils.py") - utils = _utils.import_module_from_file(utils_path) - - # Modeling - model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - my_loss = ort_utils.my_loss - if legacy_api: - if dynamic_axes: - model_desc = ort_utils.legacy_transformer_model_description_dynamic_axes() - else: - model_desc = ort_utils.legacy_transformer_model_description() - else: - if dynamic_axes: - model_desc = ort_utils.transformer_model_description_dynamic_axes() - else: - model_desc = ort_utils.transformer_model_description() - - # Preparing data - train_data, val_data, test_data = utils.prepare_data(device, 20, 20, data_dir) - return model, model_desc, my_loss, utils.get_batch, train_data, val_data, test_data - - -def generate_random_input_from_bart_model_desc(desc, seed=1, device="cuda:0"): - """Generates a sample input for the BART model using the model desc""" - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - dtype = torch.int64 - vocab_size = 30528 - sample_input = [] - for _index, input in enumerate(desc["inputs"]): - size = [] - for s in input[1]: - if isinstance(s, (int)): - size.append(s) - else: - size.append(1) - sample_input.append(torch.randint(0, vocab_size, tuple(size), dtype=dtype).to(device)) - return sample_input - - -def _load_bart_model(): - bart_onnx_model_path = os.path.join("testdata", "bart_tiny.onnx") - model = onnx.load(bart_onnx_model_path) - batch = 2 - seq_len = 1024 - model_desc = { - "inputs": [ - ( - "src_tokens", - [batch, seq_len], - ), - ( - "prev_output_tokens", - [batch, seq_len], - ), - ( - "target", - [batch * seq_len], - ), - ], - "outputs": [("loss", [], True)], - } - - return model, model_desc - - -def assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint, reshape_states=False): - """Assert that the two ORTTrainer (hierarchical) state dictionaries are very close for all states""" - - assert ("model" in state_dict_pre_checkpoint) == ("model" in state_dict_post_checkpoint) - assert ("optimizer" in state_dict_pre_checkpoint) == ("optimizer" in state_dict_post_checkpoint) - - if "model" in state_dict_pre_checkpoint: - for model_state_key in state_dict_pre_checkpoint["model"]["full_precision"]: - if reshape_states: - assert_allclose( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key], - state_dict_post_checkpoint["model"]["full_precision"][model_state_key].reshape( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key].shape - ), - ) - else: - assert_allclose( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key], - state_dict_post_checkpoint["model"]["full_precision"][model_state_key], - ) - - if "optimizer" in state_dict_pre_checkpoint: - for model_state_key in state_dict_pre_checkpoint["optimizer"]: - for optimizer_state_key in state_dict_pre_checkpoint["optimizer"][model_state_key]: - if reshape_states: - assert_allclose( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key], - state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key].reshape( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key].shape - ), - ) - else: - assert_allclose( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key], - state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key], - ) diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index a9a4c7b1cc2ef..8f2a18b5ec00b 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -1,30 +1,11 @@ import copy import os -import numpy as np import torch from numpy.testing import assert_allclose -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import orttrainer - -try: - from onnxruntime.training.ortmodule import ORTModule - from onnxruntime.training.ortmodule._fallback import ORTModuleInitException - from onnxruntime.training.ortmodule._graph_execution_manager_factory import ( # noqa: F401 - GraphExecutionManagerFactory, - ) -except ImportError: - # Some pipelines do not contain ORTModule - pass -except Exception as e: - from onnxruntime.training.ortmodule._fallback import ORTModuleInitException - - if isinstance(e, ORTModuleInitException): - # ORTModule is present but not ready to run - # That is OK because this file is also used by ORTTrainer tests - pass - raise +from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory # noqa: F401 def is_all_or_nothing_fallback_enabled(model, policy=None): @@ -66,103 +47,6 @@ def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0): assert_allclose(output_a, output_b, rtol=rtol, atol=atol, err_msg="Model output value mismatch") -def assert_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0): - r"""Asserts whether weight difference between models a and b differences are within specified tolerance - - Compares the weights of two different ONNX models (model_a and model_b) - and raises AssertError when they diverge by more than atol or rtol - - Args: - model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure - verbose (bool, default is False): if True, prints absolute difference for each weight - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer) - state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state() - assert len(state_dict_a.items()) == len(state_dict_b.items()) - _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol) - - -def assert_legacy_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0): - r"""Asserts whether weight difference between models a and b differences are within specified tolerance - - Compares the weights of a legacy model model_a and experimental model_b model - and raises AssertError when they diverge by more than atol or rtol. - - Args: - model_a (ORTTrainer): Instance of legacy ORTTrainer - model_b (ORTTrainer): Instance of experimental ORTTrainer - verbose (bool, default is False): if True, prints absolute difference for each weight. - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, Legacy_ORTTrainer) - state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b.session.get_state() - assert len(state_dict_a.items()) == len(state_dict_b.items()) - _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol) - - -def _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol): - r"""Asserts whether dicts a and b value differences are within specified tolerance - - Compares the weights of two model's state_dict dicts and raises AssertError - when they diverge by more than atol or rtol - - Args: - model_a (ORTTrainer): Instance of legacy ORTTrainer - model_b (ORTTrainer): Instance of experimental ORTTrainer - verbose (bool, default is False): if True, prints absolute difference for each weight. - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - - for (a_name, a_val), (_b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()): - np_a_vals = np.array(a_val).flatten() - np_b_vals = np.array(b_val).flatten() - assert np_a_vals.shape == np_b_vals.shape - if verbose: - print(f"Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}") - assert_allclose(a_val, b_val, rtol=rtol, atol=atol, err_msg=f"Weight mismatch for {a_name}") - - -def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0): - r"""Asserts whether optimizer state differences are within specified tolerance - - Compares the expected and actual optimizer states of dicts and raises AssertError - when they diverge by more than atol or rtol. - The optimizer dict is of the form: - model_weight_name: - { - "Moment_1": moment1_tensor, - "Moment_2": moment2_tensor, - "Update_Count": update_tensor # if optimizer is adam, absent otherwise - }, - ... - "shared_optimizer_state": # if optimizer is shared, absent otherwise. - So far, only lamb optimizer uses this. - { - "step": step_tensor # int array of size 1 - } - - Args: - expected_state (dict(dict())): Expected optimizer state - actual_state (dict(dict())): Actual optimizer state - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 0): Max absolute difference - """ - assert expected_state.keys() == actual_state.keys() - for param_name, a_state in actual_state.items(): - for k, v in a_state.items(): - assert_allclose( - v, - expected_state[param_name][k], - rtol=rtol, - atol=atol, - err_msg=f"Optimizer state mismatch for param {param_name}, key {k}", - ) - - def is_dynamic_axes(model): # Check inputs for inp in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.input: diff --git a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py b/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py deleted file mode 100644 index d5298cf8e860e..0000000000000 --- a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py +++ /dev/null @@ -1,325 +0,0 @@ -import os -import unittest - -import torch -import torch.nn as nn -from orttraining_test_bert_postprocess import postprocess_model -from orttraining_test_data_loader import create_ort_test_dataloader -from orttraining_test_transformers import BertForPreTraining, BertModelTest -from orttraining_test_utils import map_optimizer_attributes - -import onnxruntime -from onnxruntime.capi.ort_trainer import ( # noqa: F401 - IODescription, - LossScaler, - ModelDescription, - ORTTrainer, - generate_sample, -) - -torch.manual_seed(1) -onnxruntime.set_seed(1) - - -class Test_PostPasses(unittest.TestCase): # noqa: N801 - def get_onnx_model( - self, model, model_desc, inputs, device, _enable_internal_postprocess=True, _extra_postprocess=None - ): - lr_desc = IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes, - lr_desc, - device, - world_rank=0, - world_size=1, - _opset_version=14, - _enable_internal_postprocess=_enable_internal_postprocess, - _extra_postprocess=_extra_postprocess, - ) - - model.train_step(*inputs) - return model.onnx_model_ - - def count_all_nodes(self, model): - return len(model.graph.node) - - def count_nodes(self, model, node_type): - count = 0 - for node in model.graph.node: - if node.op_type == node_type: - count += 1 - return count - - def find_nodes(self, model, node_type): - nodes = [] - for node in model.graph.node: - if node.op_type == node_type: - nodes.append(node) - return nodes - - def get_name(self, name): - if os.path.exists(name): - return name - rel = os.path.join("testdata", name) - if os.path.exists(rel): - return rel - this = os.path.dirname(__file__) - data = os.path.join(this, "..", "..", "..", "..", "onnxruntime", "test", "testdata") - res = os.path.join(data, name) - if os.path.exists(res): - return res - raise FileNotFoundError(f"Unable to find '{name}' or '{rel}' or '{res}'") - - def test_layer_norm(self): - class LayerNormNet(nn.Module): - def __init__(self, target): - super().__init__() - self.ln_1 = nn.LayerNorm(10) - self.loss = nn.CrossEntropyLoss() - self.target = target - - def forward(self, x): - output1 = self.ln_1(x) - loss = self.loss(output1, self.target) - return loss, output1 - - device = torch.device("cpu") - target = torch.ones(20, 10, 10, dtype=torch.int64).to(device) - model = LayerNormNet(target) - input = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(device) - - input_desc = IODescription("input", [], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [20, 5, 10, 10], "float32") - model_desc = ModelDescription([input_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [input, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, input_args, device) - - count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization") - count_nodes = self.count_all_nodes(onnx_model) - - assert count_layer_norm == 0 - assert count_nodes == 3 - - def test_expand(self): - class ExpandNet(nn.Module): - def __init__(self, target): - super().__init__() - self.loss = nn.CrossEntropyLoss() - self.target = target - self.linear = torch.nn.Linear(2, 2) - - def forward(self, x, x1): - output = x.expand_as(x1) - output = self.linear(output) - output = output + output - loss = self.loss(output, self.target) - return loss, output - - device = torch.device("cpu") - target = torch.ones(5, 5, 2, dtype=torch.int64).to(device) - model = ExpandNet(target).to(device) - - x = torch.randn(5, 3, 1, 2, dtype=torch.float32).to(device) - x1 = torch.randn(5, 3, 5, 2, dtype=torch.float32).to(device) - - input0_desc = IODescription("x", [5, 3, 1, 2], "float32") - input1_desc = IODescription("x1", [5, 3, 5, 2], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [5, 3, 5, 2], "float32") - model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [x, x1, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, input_args, device) - - # check that expand output has shape - expand_nodes = self.find_nodes(onnx_model, "Expand") - assert len(expand_nodes) == 1 - - model_info = onnx_model.graph.value_info - assert model_info[0].name == expand_nodes[0].output[0] - assert model_info[0].type == onnx_model.graph.input[1].type - - def test_bert(self): - device = torch.device("cpu") - - model_tester = BertModelTest.BertModelTester(self) - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = model_tester.prepare_config_and_inputs() - - model = BertForPreTraining(config=config) - model.eval() - - loss, prediction_scores, seq_relationship_score = model( - input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - masked_lm_labels=token_labels, - next_sentence_label=sequence_labels, - ) - - model_desc = ModelDescription( - [ - model_tester.input_ids_desc, - model_tester.attention_mask_desc, - model_tester.token_type_ids_desc, - model_tester.masked_lm_labels_desc, - model_tester.next_sentence_label_desc, - ], - [model_tester.loss_desc, model_tester.prediction_scores_desc, model_tester.seq_relationship_scores_desc], - ) - - from collections import namedtuple - - MyArgs = namedtuple( - "MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len" - ) - args = MyArgs( - local_rank=0, - world_size=1, - max_steps=100, - learning_rate=0.00001, - warmup_proportion=0.01, - batch_size=13, - seq_len=7, - ) - - dataset_len = 100 - dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device) - learning_rate = torch.tensor(1.0e0, dtype=torch.float32).to(device) - for b in dataloader: - batch = b - break - learning_rate = torch.tensor([1.00e00]).to(device) - inputs = [*batch, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, inputs, device, _extra_postprocess=postprocess_model) - - self._bert_helper(onnx_model) - - def _bert_helper(self, onnx_model): - # count layer_norm - count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization") - assert count_layer_norm == 0 - - # get expand node and check output shape - expand_nodes = self.find_nodes(onnx_model, "Expand") - assert len(expand_nodes) == 1 - - model_info = onnx_model.graph.value_info - assert model_info[0].name == expand_nodes[0].output[0] - assert model_info[0].type == onnx_model.graph.input[0].type - - def test_extra_postpass(self): - def postpass_replace_first_add_with_sub(model): - # this post pass replaces the first Add node with Sub in the model. - # Previous graph - # (subgraph 1) (subgraph 2) - # | | - # | | - # |________ ________| - # | | - # Add - # | - # (subgraph 3) - # - # Post graph - # (subgraph 1) (subgraph 2) - # | | - # | | - # |________ ________| - # | | - # Sub - # | - # (subgraph 3) - add_nodes = [n for n in model.graph.node if n.op_type == "Add"] - add_nodes[0].op_type = "Sub" - - class MultiAdd(nn.Module): - def __init__(self, target): - super().__init__() - self.loss = nn.CrossEntropyLoss() - self.target = target - self.linear = torch.nn.Linear(2, 2, bias=False) - - def forward(self, x, x1): - output = x + x1 - output = output + x - output = output + x1 - output = self.linear(output) - loss = self.loss(output, self.target) - return loss, output - - device = torch.device("cpu") - target = torch.ones(5, 2, dtype=torch.int64).to(device) - model = MultiAdd(target).to(device) - - x = torch.randn(5, 5, 2, dtype=torch.float32).to(device) - x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device) - - input0_desc = IODescription("x", [5, 5, 2], "float32") - input1_desc = IODescription("x1", [5, 5, 2], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [5, 5, 2], "float32") - model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [x, x1, learning_rate] - - onnx_model = self.get_onnx_model( - model, model_desc, input_args, device, _extra_postprocess=postpass_replace_first_add_with_sub - ) - - # check that extra postpass is called, and called only once. - add_nodes = self.find_nodes(onnx_model, "Add") - sub_nodes = self.find_nodes(onnx_model, "Sub") - assert len(add_nodes) == 2 - assert len(sub_nodes) == 1 - - unprocessed_onnx_model = self.get_onnx_model( - model, model_desc, input_args, device, _extra_postprocess=None, _enable_internal_postprocess=False - ) - # check that the model is unchanged. - add_nodes = self.find_nodes(unprocessed_onnx_model, "Add") - sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub") - assert len(add_nodes) == 3 - assert len(sub_nodes) == 0 - - processed_onnx_model = self.get_onnx_model( - unprocessed_onnx_model, - model_desc, - input_args, - device, - _extra_postprocess=postpass_replace_first_add_with_sub, - ) - # check that extra postpass is called, and called only once. - add_nodes = self.find_nodes(processed_onnx_model, "Add") - sub_nodes = self.find_nodes(processed_onnx_model, "Sub") - assert len(add_nodes) == 2 - assert len(sub_nodes) == 1 - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index 0e7e9d23ee627..5341cd053ac18 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -43,7 +43,7 @@ def run_ortmodule_ops_tests(cwd, log, transformers_cache): env = get_env_with_transformers_cache(transformers_cache) - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnx_ops_ortmodule.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_onnx_ops.py"] run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode() @@ -146,7 +146,7 @@ def run_data_sampler_tests(cwd, log): def run_hooks_tests(cwd, log): log.debug("Running: Data hooks tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_hooks.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_hooks.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() diff --git a/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py b/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py deleted file mode 100644 index eea733684f140..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py +++ /dev/null @@ -1,801 +0,0 @@ -# ================== -import dataclasses -import datetime -import glob -import json -import logging -import os -import random -import shutil -import unittest -from concurrent.futures import ProcessPoolExecutor -from dataclasses import dataclass, field -from typing import Any, Dict, Optional - -import h5py -import numpy as np -import torch -import torch.distributed as dist -from torch.utils.data import DataLoader, Dataset, RandomSampler -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers import BertConfig, BertForPreTraining, HfArgumentParser - -import onnxruntime as ort - -# need to override torch.onnx.symbolic_opset12.nll_loss to handle ignore_index == -100 cases. -# the fix for ignore_index == -100 cases is already in pytorch master. -# however to use current torch master is causing computation changes in many tests. -# eventually we will use pytorch with fixed nll_loss once computation -# issues are understood and solved. -import onnxruntime.capi.pt_patch -from onnxruntime.training import amp, optim, orttrainer -from onnxruntime.training.checkpoint import aggregate_checkpoints -from onnxruntime.training.optim import LinearWarmupLRScheduler, PolyWarmupLRScheduler # noqa: F401 - -# we cannot make full convergence run in nightly pipeling because of its timeout limit, -# max_steps is still needed to calculate learning rate. force_to_stop_max_steps is used to -# terminate the training before the pipeline run hit its timeout. -force_to_stop_max_steps = 2500 - -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO -) -logger = logging.getLogger(__name__) - - -def get_rank(): - if not dist.is_available(): - return 0 - if not dist.is_initialized(): - return 0 - return dist.get_rank() - - -def is_main_process(args): - if hasattr(args, "world_rank"): - return args.world_rank in [-1, 0] - else: - return get_rank() == 0 - - -def bert_model_description(config): - vocab_size = config.vocab_size - new_model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - ), - ( - "next_sentence_label", - [ - "batch", - ], - ), - ], - "outputs": [ - ("loss", [], True), - ( - "prediction_scores", - ["batch", "max_seq_len_in_batch", vocab_size], - ), - ( - "seq_relationship_scores", - ["batch", 2], - ), - ], - } - return new_model_desc - - -def create_pretraining_dataset(input_file, max_pred_length, args): - train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length) - train_sampler = RandomSampler(train_data) - train_dataloader = DataLoader( - train_data, sampler=train_sampler, batch_size=args.train_batch_size * args.n_gpu, num_workers=0, pin_memory=True - ) - return train_dataloader, input_file - - -class pretraining_dataset(Dataset): # noqa: N801 - def __init__(self, input_file, max_pred_length): - logger.info("pretraining_dataset: %s, max_pred_length: %d", input_file, max_pred_length) - self.input_file = input_file - self.max_pred_length = max_pred_length - f = h5py.File(input_file, "r") - keys = [ - "input_ids", - "input_mask", - "segment_ids", - "masked_lm_positions", - "masked_lm_ids", - "next_sentence_labels", - ] - self.inputs = [np.asarray(f[key][:]) for key in keys] - f.close() - - def __len__(self): - "Denotes the total number of samples" - return len(self.inputs[0]) - - def __getitem__(self, index): - [input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) - if indice < 5 - else torch.from_numpy(np.asarray(input[index].astype(np.int64))) - for indice, input in enumerate(self.inputs) - ] - - # HF model use default ignore_index value (-100) for CrossEntropyLoss - masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -100 - index = self.max_pred_length - # store number of masked tokens in index - padded_mask_indices = (masked_lm_positions == 0).nonzero() - if len(padded_mask_indices) != 0: - index = padded_mask_indices[0].item() - masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] - return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels] - - -import argparse # noqa: E402 - - -def parse_arguments(): - parser = argparse.ArgumentParser() - - # batch size test config parameters - parser.add_argument( - "--enable_mixed_precision", - default=False, - action="store_true", - help="Whether to use 16-bit float precision instead of 32-bit", - ) - - parser.add_argument( - "--sequence_length", - default=512, - type=int, - help="The maximum total input sequence length after WordPiece tokenization. \n" - "Sequences longer than this will be truncated, and sequences shorter \n" - "than this will be padded.", - ) - parser.add_argument( - "--max_predictions_per_seq", default=80, type=int, help="The maximum total of masked tokens in input sequence" - ) - parser.add_argument("--max_batch_size", default=32, type=int, help="Total batch size for training.") - - parser.add_argument("--gelu_recompute", default=False, action="store_true") - - parser.add_argument("--attn_dropout_recompute", default=False, action="store_true") - - parser.add_argument("--transformer_layer_recompute", default=False, action="store_true") - - args = parser.parse_args() - return args - - -@dataclass -class PretrainArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - input_dir: str = field( - default=None, metadata={"help": "The input data dir. Should contain .hdf5 files for the task"} - ) - - bert_model: str = field( - default=None, - metadata={ - "help": "Bert pre-trained model selected in the list: bert-base-uncased, \ - bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." - }, - ) - - output_dir: str = field( - default=None, metadata={"help": "The output directory where the model checkpoints will be written."} - ) - - cache_dir: str = field( - default="/tmp/bert_pretrain/", - metadata={"help": "The output directory where the model checkpoints will be written."}, - ) - max_seq_length: Optional[int] = field( - default=512, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer \ - than this will be truncated, sequences shorter will be padded." - }, - ) - - max_predictions_per_seq: Optional[int] = field( - default=80, metadata={"help": "The maximum total of masked tokens in input sequence."} - ) - - train_batch_size: Optional[int] = field(default=32, metadata={"help": "Batch size for training."}) - - learning_rate: Optional[float] = field(default=5e-5, metadata={"help": "The initial learning rate for Lamb."}) - - num_train_epochs: Optional[float] = field( - default=3.0, metadata={"help": "Total number of training epochs to perform."} - ) - - max_steps: Optional[float] = field(default=1000, metadata={"help": "Total number of training steps to perform."}) - - warmup_proportion: Optional[float] = field( - default=0.01, - metadata={ - "help": "Proportion of training to perform linear learning rate warmup for. \ - E.g., 0.1 = 10%% of training." - }, - ) - - local_rank: Optional[int] = field(default=-1, metadata={"help": "local_rank for distributed training on gpus."}) - - world_rank: Optional[int] = field(default=-1) - - world_size: Optional[int] = field(default=1) - - seed: Optional[int] = field(default=42, metadata={"help": "random seed for initialization."}) - - gradient_accumulation_steps: Optional[int] = field( - default=1, metadata={"help": "Number of updates steps to accumualte before performing a backward/update pass."} - ) - - fp16: bool = field(default=False, metadata={"help": "Whether to use 16-bit float precision instead of 32-bit."}) - - gelu_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing Gelu activation output to save memory."} - ) - attn_dropout_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing attention dropout to save memory."} - ) - transformer_layer_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing transformer layerwise to save memory."} - ) - - loss_scale: Optional[float] = field( - default=0.0, metadata={"help": "Loss scaling, positive power of 2 values can improve fp16 convergence."} - ) - - deepspeed_zero_stage: Optional[int] = field(default=0, metadata={"help": "Deepspeed Zero Stage. 0 => disabled"}) - - log_freq: Optional[float] = field(default=1.0, metadata={"help": "frequency of logging loss."}) - - checkpoint_activations: bool = field(default=False, metadata={"help": "Whether to use gradient checkpointing."}) - - resume_from_checkpoint: bool = field( - default=False, metadata={"help": "Whether to resume training from checkpoint."} - ) - - resume_step: Optional[int] = field(default=-1, metadata={"help": "Step to resume training from."}) - - num_steps_per_checkpoint: Optional[int] = field( - default=100, metadata={"help": "Number of update steps until a model checkpoint is saved to disk."} - ) - - save_checkpoint: Optional[bool] = field( - default=False, metadata={"help": "Enable for saving a model checkpoint to disk."} - ) - - init_state_dict: Optional[dict] = field(default=None, metadata={"help": "State to load before training."}) - - phase2: bool = field(default=False, metadata={"help": "Whether to train with seq len 512."}) - - allreduce_post_accumulation: bool = field( - default=False, metadata={"help": "Whether to do allreduces during gradient accumulation steps."} - ) - - allreduce_post_accumulation_fp16: bool = field( - default=False, metadata={"help": "Whether to do fp16 allreduce post accumulation."} - ) - - accumulate_into_fp16: bool = field(default=False, metadata={"help": "Whether to use fp16 gradient accumulators."}) - - phase1_end_step: Optional[int] = field( - default=7038, metadata={"help": "Whether to use fp16 gradient accumulators."} - ) - - tensorboard_dir: Optional[str] = field( - default=None, - ) - - schedule: Optional[str] = field( - default="warmup_poly", - ) - - # this argument is test specific. to run a full bert model will take too long to run. instead, we reduce - # number of hidden layers so that it can show convergence to an extend to help detect any regression. - force_num_hidden_layers: Optional[int] = field( - default=None, metadata={"help": "Whether to use fp16 gradient accumulators."} - ) - - def to_json_string(self): - """ - Serializes this instance to a JSON string. - """ - return json.dumps(dataclasses.asdict(self), indent=2) - - def to_sanitized_dict(self) -> Dict[str, Any]: - """ - Sanitized serialization to use with TensorBoard`s hparams - """ - d = dataclasses.asdict(self) - valid_types = [bool, int, float, str, torch.Tensor] - return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} - - -def setup_training(args): - assert torch.cuda.is_available() - - if args.local_rank == -1: - args.local_rank = 0 - args.world_rank = 0 - - print("args.local_rank: ", args.local_rank) - torch.cuda.set_device(args.local_rank) - device = torch.device("cuda", args.local_rank) - args.n_gpu = 1 - - if args.gradient_accumulation_steps < 1: - raise ValueError( - f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1" - ) - if args.train_batch_size % args.gradient_accumulation_steps != 0: - raise ValueError( - "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format( - args.gradient_accumulation_steps, args.train_batch_size - ) - ) - - # args.train_batch_size is per global step (optimization step) batch size - # now make it a per gpu batch size - args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps - args.train_batch_size = args.train_batch_size // args.world_size - - logger.info("setup_training: args.train_batch_size = %d", args.train_batch_size) - return device, args - - -def setup_torch_distributed(world_rank, world_size): - os.environ["RANK"] = str(world_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12345" - torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=world_rank) - return - - -def prepare_model(args, device): - config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir) - - # config.num_hidden_layers = 12 - if args.force_num_hidden_layers: - logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers) - config.num_hidden_layers = args.force_num_hidden_layers - - model = BertForPreTraining(config) - if args.init_state_dict is not None: - model.load_state_dict(args.init_state_dict) - model_desc = bert_model_description(config) - - lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion) - - loss_scaler = amp.DynamicLossScaler() if args.fp16 else None - - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": args.gradient_accumulation_steps}, - "device": {"id": str(device)}, - "mixed_precision": {"enabled": args.fp16, "loss_scaler": loss_scaler}, - "graph_transformer": { - "attn_dropout_recompute": args.attn_dropout_recompute, - "gelu_recompute": args.gelu_recompute, - "transformer_layer_recompute": args.transformer_layer_recompute, - }, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": True}, - "distributed": { - "world_rank": max(0, args.local_rank), - "world_size": args.world_size, - "local_rank": max(0, args.local_rank), - "allreduce_post_accumulation": args.allreduce_post_accumulation, - "deepspeed_zero_optimization": {"stage": args.deepspeed_zero_stage}, - "enable_adasum": False, - }, - "lr_scheduler": lr_scheduler, - } - ) - - param_optimizer = list(model.named_parameters()) - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - params = [ - { - "params": [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - { - "params": [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - ] - - optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) - model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options) - - return model - - -def get_data_file(f_id, world_rank, world_size, files): - num_files = len(files) - if world_size > num_files: - remainder = world_size % num_files - return files[(f_id * world_size + world_rank + remainder * f_id) % num_files] - elif world_size > 1: - return files[(f_id * world_size + world_rank) % num_files] - else: - return files[f_id % num_files] - - -def main(): - parser = HfArgumentParser(PretrainArguments) - args = parser.parse_args_into_dataclasses()[0] - do_pretrain(args) - - -def do_pretrain(args): - if is_main_process(args) and args.tensorboard_dir: - tb_writer = SummaryWriter(log_dir=args.tensorboard_dir) - tb_writer.add_text("args", args.to_json_string()) - tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={}) - else: - tb_writer = None - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - ort.set_seed(args.seed) - - device, args = setup_training(args) - - model = prepare_model(args, device) - - logger.info("Running training: Batch size = %d, initial LR = %f", args.train_batch_size, args.learning_rate) - - average_loss = 0.0 - epoch = 0 - training_steps = 0 - - pool = ProcessPoolExecutor(1) - while True: - files = [ - os.path.join(args.input_dir, f) - for f in os.listdir(args.input_dir) - if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in f - ] - files.sort() - random.shuffle(files) - - f_id = 0 - train_dataloader, data_file = create_pretraining_dataset( - get_data_file(f_id, args.world_rank, args.world_size, files), args.max_predictions_per_seq, args - ) - - for f_id in range(1, len(files)): - logger.info("data file %s" % (data_file)) - - dataset_future = pool.submit( - create_pretraining_dataset, - get_data_file(f_id, args.world_rank, args.world_size, files), - args.max_predictions_per_seq, - args, - ) - - train_iter = tqdm(train_dataloader, desc="Iteration") if is_main_process(args) else train_dataloader - for _step, batch in enumerate(train_iter): - training_steps += 1 - batch = [t.to(device) for t in batch] # noqa: PLW2901 - input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch - - loss, _, _ = model.train_step( - input_ids, input_mask, segment_ids, masked_lm_labels, next_sentence_labels - ) - average_loss += loss.item() - - global_step = model._train_step_info.optimization_step - if training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0: - if is_main_process(args): - divisor = args.log_freq * args.gradient_accumulation_steps - if tb_writer: - lr = model.options.lr_scheduler.get_last_lr()[0] - tb_writer.add_scalar("train/summary/scalar/Learning_Rate", lr, global_step) - if args.fp16: - tb_writer.add_scalar("train/summary/scalar/loss_scale_25", loss, global_step) - # TODO: ORTTrainer to expose all_finite - # tb_writer.add_scalar('train/summary/scalar/all_fp16_gradients_finite_859', all_finite, global_step) - tb_writer.add_scalar("train/summary/total_loss", average_loss / divisor, global_step) - - print(f"Step:{global_step} Average Loss = {average_loss / divisor}") - - if global_step >= args.max_steps or global_step >= force_to_stop_max_steps: - if tb_writer: - tb_writer.close() - - if global_step >= args.max_steps: - if args.save_checkpoint: - model.save_checkpoint(os.path.join(args.output_dir, f"checkpoint-{args.world_rank}.ortcp")) - final_loss = average_loss / (args.log_freq * args.gradient_accumulation_steps) - return final_loss - - average_loss = 0 - - del train_dataloader - - train_dataloader, data_file = dataset_future.result(timeout=None) - - epoch += 1 - - -def generate_tensorboard_logdir(root_dir): - current_date_time = datetime.datetime.today() - - dt_string = current_date_time.strftime("BERT_pretrain_%y_%m_%d_%I_%M_%S") - return os.path.join(root_dir, dt_string) - - -class ORTBertPretrainTest(unittest.TestCase): - def setUp(self): - self.output_dir = "/bert_data/hf_data/test_out/bert_pretrain_results" - self.bert_model = "bert-base-uncased" - self.local_rank = -1 - self.world_rank = -1 - self.world_size = 1 - self.max_steps = 300000 - self.learning_rate = 5e-4 - self.max_seq_length = 512 - self.max_predictions_per_seq = 20 - self.input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - self.train_batch_size = 4096 - self.gradient_accumulation_steps = 64 - self.fp16 = True - self.allreduce_post_accumulation = True - self.tensorboard_dir = "/bert_data/hf_data/test_out" - - def test_pretrain_throughput(self, process_args=None): - if process_args.sequence_length == 128: - input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - else: - input_dir = "/bert_data/hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - - print("process_args.enable_mixed_precision: ", process_args.enable_mixed_precision) - print("process_args.sequence_length: ", process_args.sequence_length) - print("process_args.max_batch_size: ", process_args.max_batch_size) - print("process_args.max_predictions_per_seq: ", process_args.max_predictions_per_seq) - print("process_args.gelu_recompute: ", process_args.gelu_recompute) - print("process_args.attn_dropout_recompute: ", process_args.attn_dropout_recompute) - print("process_args.transformer_layer_recompute: ", process_args.transformer_layer_recompute) - - args = PretrainArguments( - input_dir=input_dir, - output_dir="/bert_data/hf_data/test_out/bert_pretrain_results", - bert_model="bert-large-uncased", - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=10, - learning_rate=5e-4, - max_seq_length=process_args.sequence_length, - max_predictions_per_seq=process_args.max_predictions_per_seq, - train_batch_size=process_args.max_batch_size, - gradient_accumulation_steps=1, - fp16=process_args.enable_mixed_precision, - gelu_recompute=process_args.gelu_recompute, - attn_dropout_recompute=process_args.attn_dropout_recompute, - transformer_layer_recompute=process_args.transformer_layer_recompute, - allreduce_post_accumulation=True, - # TODO: remove - force_num_hidden_layers=2, - ) - do_pretrain(args) - - def test_pretrain_convergence(self): - args = PretrainArguments( - output_dir=self.output_dir, - bert_model=self.bert_model, - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=self.max_steps, - learning_rate=self.learning_rate, - max_seq_length=self.max_seq_length, - max_predictions_per_seq=self.max_predictions_per_seq, - train_batch_size=self.train_batch_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - input_dir=self.input_dir, - fp16=self.fp16, - allreduce_post_accumulation=self.allreduce_post_accumulation, - force_num_hidden_layers=self.force_num_hidden_layers, - tensorboard_dir=generate_tensorboard_logdir("/bert_data/hf_data/test_out/"), - ) - final_loss = do_pretrain(args) - return final_loss - - def test_pretrain_zero(self): - assert self.world_size > 0, "ZeRO test requires a distributed run." - setup_torch_distributed(self.world_rank, self.world_size) - per_gpu_batch_size = 32 - optimization_batch_size = per_gpu_batch_size * self.world_size # set to disable grad accumulation - - self.train_batch_size = optimization_batch_size - self.gradient_accumulation_steps = 1 - self.deepspeed_zero_stage = 1 - self.force_num_hidden_layers = 2 - self.max_seq_length = 32 - self.output_dir = "./bert_pretrain_ckpt" - if self.world_rank == 0: - if os.path.isdir(self.output_dir): - shutil.rmtree(self.output_dir) - os.makedirs(self.output_dir, exist_ok=True) - - torch.distributed.barrier() - - assert os.path.exists(self.output_dir) - - # run a few optimization steps - self.max_steps = 200 - args = PretrainArguments( - output_dir=self.output_dir, - bert_model=self.bert_model, - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=self.max_steps, - learning_rate=self.learning_rate, - max_seq_length=self.max_seq_length, - max_predictions_per_seq=self.max_predictions_per_seq, - train_batch_size=self.train_batch_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - input_dir=self.input_dir, - fp16=self.fp16, - allreduce_post_accumulation=self.allreduce_post_accumulation, - force_num_hidden_layers=self.force_num_hidden_layers, - deepspeed_zero_stage=self.deepspeed_zero_stage, - save_checkpoint=True, - ) - do_pretrain(args) - - # ensure all workers reach this point before loading the checkpointed state - torch.distributed.barrier() - - # on rank 0, load the trained state - if args.world_rank == 0: - checkpoint_files = glob.glob(os.path.join(self.output_dir, "checkpoint*.ortcp")) - args.init_state_dict = aggregate_checkpoints(checkpoint_files, pytorch_format=True) - - torch.distributed.barrier() - - # run a single step to get the loss, on rank 0 should be lesser than starting loss - args.save_checkpoint = False - args.max_steps = 1 - args.deepspeed_zero_stage = 0 - final_loss = do_pretrain(args) - return final_loss - - -if __name__ == "__main__": - import sys - - logger.warning("sys.argv: %s", sys.argv) - # usage: - # data parallel training - # mpirun -n 4 python orttraining_run_bert_pretrain.py - # - # single gpu: - # python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput - # [batch size test arguments] - # python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_convergence - # - # pytorch.distributed.launch will not work because ort backend requires MPI to broadcast ncclUniqueId - # calling unpublished get_mpi_context_xxx to get rank/size numbers. - try: - # In case ORT is not built with MPI/NCCL, there are no get_mpi_context_xxx internal apis. - from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size - - has_get_mpi_context_internal_api = True - except ImportError: - has_get_mpi_context_internal_api = False - pass - if has_get_mpi_context_internal_api and get_mpi_context_world_size() > 1: - world_size = get_mpi_context_world_size() - print("get_mpi_context_world_size(): ", world_size) - local_rank = get_mpi_context_local_rank() - - if local_rank == 0: - print("================================================================> os.getpid() = ", os.getpid()) - - test = ORTBertPretrainTest() - test.setUp() - test.local_rank = local_rank - test.world_rank = local_rank - test.world_size = world_size - - if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_zero": - logger.info("running ORTBertPretrainTest.test_pretrain_zero()...") - final_loss = test.test_pretrain_zero() - logger.info("ORTBertPretrainTest.test_pretrain_zero() rank = %i final loss = %f", local_rank, final_loss) - if local_rank == 0: - test.assertLess(final_loss, 10.2) - else: - test.assertGreater(final_loss, 11.0) - logger.info("ORTBertPretrainTest.test_pretrain_zero() passed") - elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence": - logger.info("running ORTBertPretrainTest.test_pretrain_convergence()...") - test.max_steps = 200 - test.force_num_hidden_layers = 8 - final_loss = test.test_pretrain_convergence() - logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss) - test.assertLess(final_loss, 8.5) - logger.info("ORTBertPretrainTest.test_pretrain_convergence() passed") - else: - # https://microsoft.sharepoint.com/teams/ONNX2/_layouts/15/Doc.aspx?sourcedoc={170774be-e1c6-4f8b-a3ae-984f211fe410}&action=edit&wd=target%28ONNX%20Training.one%7C8176133b-c7cb-4ef2-aa9d-3fdad5344c40%2FGitHub%20Master%20Merge%20Schedule%7Cb67f0db1-e3a0-4add-80a6-621d67fd8107%2F%29 - # to make equivalent args for cpp convergence test - test.max_seq_length = 128 - test.max_predictions_per_seq = 20 - test.gradient_accumulation_steps = 16 - - # cpp_batch_size (=64) * grad_acc * world_size - test.train_batch_size = 64 * test.gradient_accumulation_steps * test.world_size - test.max_steps = 300000 - - test.force_num_hidden_layers = None - - # already using Adam (e.g. AdamConfig) - test.learning_rate = 5e-4 - test.warmup_proportion = 0.1 - - final_loss = test.test_pretrain_convergence() - logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss) - else: - # unittest does not accept user defined arguments - # we need to run this script with user defined arguments - if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_throughput": - run_test_pretrain_throughput, run_test_pretrain_convergence = True, False - sys.argv.remove("ORTBertPretrainTest.test_pretrain_throughput") - elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence": - run_test_pretrain_throughput, run_test_pretrain_convergence = False, True - sys.argv.remove("ORTBertPretrainTest.test_pretrain_convergence") - else: - run_test_pretrain_throughput, run_test_pretrain_convergence = True, True - process_args = parse_arguments() - test = ORTBertPretrainTest() - test.setUp() - - if run_test_pretrain_throughput: - logger.info("running single GPU ORTBertPretrainTest.test_pretrain_throughput()...") - test.test_pretrain_throughput(process_args) - logger.info("single GPU ORTBertPretrainTest.test_pretrain_throughput() passed") - - # unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py b/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py deleted file mode 100644 index 3e2d1a7154bfd..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py +++ /dev/null @@ -1,67 +0,0 @@ -import collections -import subprocess -import sys - -Config = collections.namedtuple( - "Config", - [ - "enable_mixed_precision", - "sequence_length", - "max_batch_size", - "max_predictions_per_seq", - "gelu_recompute", - "attn_dropout_recompute", - "transformer_layer_recompute", - ], -) - -configs = [ - Config(True, 128, 46, 20, False, False, False), - Config(True, 512, 8, 80, False, False, False), - Config(False, 128, 26, 20, False, False, False), - Config(False, 512, 4, 80, False, False, False), - Config(True, 128, 50, 20, True, False, False), - Config(True, 128, 50, 20, False, True, False), - Config(True, 128, 76, 20, False, False, True), - Config(True, 512, 8, 80, True, False, False), - Config(True, 512, 9, 80, False, True, False), - Config(True, 512, 15, 80, False, False, True), -] - - -def run_with_config(config): - print( - "##### testing name - {}-{} #####".format( - "fp16" if config.enable_mixed_precision else "fp32", config.sequence_length - ) - ) - print("gelu_recompute: ", config.gelu_recompute) - print("attn_dropout_recompute: ", config.attn_dropout_recompute) - print("transformer_layer_recompute: ", config.transformer_layer_recompute) - - cmds = [ - sys.executable, - "orttraining_run_bert_pretrain.py", - "ORTBertPretrainTest.test_pretrain_throughput", - "--sequence_length", - str(config.sequence_length), - "--max_batch_size", - str(config.max_batch_size), - "--max_predictions_per_seq", - str(config.max_predictions_per_seq), - ] - if config.enable_mixed_precision: - cmds.append("--enable_mixed_precision") - if config.gelu_recompute: - cmds.append("--gelu_recompute") - if config.attn_dropout_recompute: - cmds.append("--attn_dropout_recompute") - if config.transformer_layer_recompute: - cmds.append("--transformer_layer_recompute") - - # access to azure storage shared disk is much slower so we need a longer timeout. - subprocess.run(cmds, timeout=1200).check_returncode() # noqa: PLW1510 - - -for config in configs: - run_with_config(config) diff --git a/orttraining/orttraining/test/python/orttraining_run_glue.py b/orttraining/orttraining/test/python/orttraining_run_glue.py deleted file mode 100644 index 794e2f8cc7240..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_glue.py +++ /dev/null @@ -1,323 +0,0 @@ -# adapted from run_glue.py of huggingface transformers - -import dataclasses # noqa: F401 -import logging -import os -import unittest -from dataclasses import dataclass, field -from typing import Dict, Optional - -import numpy as np -from numpy.testing import assert_allclose -from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - AutoTokenizer, - EvalPrediction, - GlueDataset, - GlueDataTrainingArguments, - TrainingArguments, - glue_compute_metrics, - glue_output_modes, - glue_tasks_num_labels, - set_seed, -) - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - -try: - from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size - - has_get_mpi_context_internal_api = True -except ImportError: - has_get_mpi_context_internal_api = False - pass - - -import torch # noqa: F401 -from orttraining_transformer_trainer import ORTTransformerTrainer - -logger = logging.getLogger(__name__) - - -def verify_old_and_new_api_are_equal(results_per_api): - new_api_results = results_per_api[True] - old_api_results = results_per_api[False] - for key in new_api_results: - assert_allclose(new_api_results[key], old_api_results[key]) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"}) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} - ) - - -class ORTGlueTest(unittest.TestCase): - def setUp(self): - # configurations not to be changed accoss tests - self.max_seq_length = 128 - self.train_batch_size = 8 - self.learning_rate = 2e-5 - self.num_train_epochs = 3.0 - self.local_rank = -1 - self.world_size = 1 - self.overwrite_output_dir = True - self.gradient_accumulation_steps = 1 - self.data_dir = "/bert_data/hf_data/glue_data/" - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "glue_test_output/") - self.cache_dir = "/tmp/glue/" - self.logging_steps = 10 - - def test_roberta_with_mrpc(self): - expected_acc = 0.85 - expected_f1 = 0.88 - expected_loss = 0.35 - results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=False) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_roberta_fp16_with_mrpc(self): - expected_acc = 0.87 - expected_f1 = 0.90 - expected_loss = 0.33 - - results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=True) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_bert_with_mrpc(self): - if self.local_rank == -1: - expected_acc = 0.83 - expected_f1 = 0.88 - expected_loss = 0.44 - elif self.local_rank == 0: - expected_acc = 0.81 - expected_f1 = 0.86 - expected_loss = 0.44 - - results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False) - - if self.local_rank in [-1, 0]: - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_bert_fp16_with_mrpc(self): - expected_acc = 0.84 - expected_f1 = 0.88 - expected_loss = 0.46 - - results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=True) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def model_to_desc(self, model_name, model): - if model_name.startswith("bert") or model_name.startswith("xlnet"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "labels", - [ - "batch", - ], - ), - ], - "outputs": [("loss", [], True), ("logits", ["batch", 2])], - } - elif model_name.startswith("roberta"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "labels", - [ - "batch", - ], - ), - ], - "outputs": [("loss", [], True), ("logits", ["batch", 2])], - } - else: - raise RuntimeError(f"unsupported base model name {model_name}.") - - return model_desc - - def run_glue(self, model_name, task_name, fp16): - model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) - data_args = GlueDataTrainingArguments( - task_name=task_name, data_dir=os.path.join(self.data_dir, task_name), max_seq_length=self.max_seq_length - ) - - training_args = TrainingArguments( - output_dir=os.path.join(self.output_dir, task_name), - do_train=True, - do_eval=True, - per_gpu_train_batch_size=self.train_batch_size, - learning_rate=self.learning_rate, - num_train_epochs=self.num_train_epochs, - local_rank=self.local_rank, - overwrite_output_dir=self.overwrite_output_dir, - gradient_accumulation_steps=self.gradient_accumulation_steps, - fp16=fp16, - logging_steps=self.logging_steps, - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, - ) - logger.warning( - "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - training_args.local_rank, - training_args.device, - training_args.n_gpu, - bool(training_args.local_rank != -1), - training_args.fp16, - ) - logger.info("Training/evaluation parameters %s", training_args) - - set_seed(training_args.seed) - onnxruntime.set_seed(training_args.seed) - - try: - num_labels = glue_tasks_num_labels[data_args.task_name] - output_mode = glue_output_modes[data_args.task_name] - except KeyError: - raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904 - - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task=data_args.task_name, - cache_dir=model_args.cache_dir, - ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - ) - - model = AutoModelForSequenceClassification.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - ) - - train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None - - eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None - - def compute_metrics(p: EvalPrediction) -> Dict: - if output_mode == "classification": - preds = np.argmax(p.predictions, axis=1) - elif output_mode == "regression": - preds = np.squeeze(p.predictions) - return glue_compute_metrics(data_args.task_name, preds, p.label_ids) - - model_desc = self.model_to_desc(model_name, model) - # Initialize the ORTTrainer within ORTTransformerTrainer - trainer = ORTTransformerTrainer( - model=model, - model_desc=model_desc, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - compute_metrics=compute_metrics, - world_size=self.world_size, - ) - - # Training - if training_args.do_train: - trainer.train() - trainer.save_model() - - # Evaluation - results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: - logger.info("*** Evaluate ***") - - result = trainer.evaluate() - - logger.info(f"***** Eval results {data_args.task_name} *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - - results.update(result) - - return results - - -if __name__ == "__main__": - if has_get_mpi_context_internal_api: - local_rank = get_mpi_context_local_rank() - world_size = get_mpi_context_world_size() - else: - local_rank = -1 - world_size = 1 - - if world_size > 1: - # mpi launch - logger.warning("mpirun launch, local_rank / world_size: %s : % s", local_rank, world_size) - - # TrainingArguments._setup_devices will call torch.distributed.init_process_group(backend="nccl") - # pytorch expects following environment settings (which would be set if launched with torch.distributed.launch). - - os.environ["RANK"] = str(local_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "29500" - - from onnxruntime.capi._pybind_state import set_cuda_device_id - - set_cuda_device_id(local_rank) - - test = ORTGlueTest() - test.setUp() - test.local_rank = local_rank - test.world_size = world_size - test.test_bert_with_mrpc() - else: - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py b/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py deleted file mode 100644 index 92db204593bcd..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py +++ /dev/null @@ -1,281 +0,0 @@ -# adapted from run_multiple_choice.py of huggingface transformers -# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/run_multiple_choice.py - -import dataclasses # noqa: F401 -import logging -import os -import unittest -from dataclasses import dataclass, field -from typing import Dict, Optional - -import numpy as np -import torch # noqa: F401 -from numpy.testing import assert_allclose # noqa: F401 -from orttraining_run_glue import verify_old_and_new_api_are_equal # noqa: F401 -from orttraining_transformer_trainer import ORTTransformerTrainer -from transformers import HfArgumentParser # noqa: F401 -from transformers import Trainer # noqa: F401 -from transformers import ( - AutoConfig, - AutoModelForMultipleChoice, - AutoTokenizer, - EvalPrediction, - TrainingArguments, - set_seed, -) -from utils_multiple_choice import MultipleChoiceDataset, Split, SwagProcessor - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - -logger = logging.getLogger(__name__) - - -def simple_accuracy(preds, labels): - return (preds == labels).mean() - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"}) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} - ) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ - - task_name: str = field(metadata={"help": "The name of the task to train on."}) - data_dir: str = field(metadata={"help": "Should contain the data files for the task."}) - max_seq_length: int = field( - default=128, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - }, - ) - overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}) - - -class ORTMultipleChoiceTest(unittest.TestCase): - def setUp(self): - # configurations not to be changed accoss tests - self.max_seq_length = 80 - self.train_batch_size = 16 - self.eval_batch_size = 2 - self.learning_rate = 2e-5 - self.num_train_epochs = 1.0 - self.local_rank = -1 - self.overwrite_output_dir = True - self.gradient_accumulation_steps = 8 - self.data_dir = "/bert_data/hf_data/swag/swagaf/data" - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "multiple_choice_test_output/") - self.cache_dir = "/tmp/multiple_choice/" - self.logging_steps = 10 - self.rtol = 2e-01 - - def test_bert_with_swag(self): - expected_acc = 0.75 - expected_loss = 0.64 - - results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=False) - assert results["acc"] >= expected_acc - assert results["loss"] <= expected_loss - - def test_bert_fp16_with_swag(self): - # larger batch can be handled with mixed precision - self.train_batch_size = 32 - - expected_acc = 0.73 - expected_loss = 0.68 - - results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=True) - assert results["acc"] >= expected_acc - assert results["loss"] <= expected_loss - - def run_multiple_choice(self, model_name, task_name, fp16): - model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) - data_args = DataTrainingArguments( - task_name=task_name, data_dir=self.data_dir, max_seq_length=self.max_seq_length - ) - - training_args = TrainingArguments( - output_dir=os.path.join(self.output_dir, task_name), - do_train=True, - do_eval=True, - per_gpu_train_batch_size=self.train_batch_size, - per_gpu_eval_batch_size=self.eval_batch_size, - learning_rate=self.learning_rate, - num_train_epochs=self.num_train_epochs, - local_rank=self.local_rank, - overwrite_output_dir=self.overwrite_output_dir, - gradient_accumulation_steps=self.gradient_accumulation_steps, - fp16=fp16, - logging_steps=self.logging_steps, - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, - ) - logger.warning( - "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - training_args.local_rank, - training_args.device, - training_args.n_gpu, - bool(training_args.local_rank != -1), - training_args.fp16, - ) - logger.info("Training/evaluation parameters %s", training_args) - - set_seed(training_args.seed) - onnxruntime.set_seed(training_args.seed) - - try: - processor = SwagProcessor() - label_list = processor.get_labels() - num_labels = len(label_list) - except KeyError: - raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904 - - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task=data_args.task_name, - cache_dir=model_args.cache_dir, - ) - - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - ) - - model = AutoModelForMultipleChoice.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - ) - - # Get datasets - train_dataset = ( - MultipleChoiceDataset( - data_dir=data_args.data_dir, - tokenizer=tokenizer, - task=data_args.task_name, - processor=processor, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.train, - ) - if training_args.do_train - else None - ) - eval_dataset = ( - MultipleChoiceDataset( - data_dir=data_args.data_dir, - tokenizer=tokenizer, - task=data_args.task_name, - processor=processor, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.dev, - ) - if training_args.do_eval - else None - ) - - def compute_metrics(p: EvalPrediction) -> Dict: - preds = np.argmax(p.predictions, axis=1) - return {"acc": simple_accuracy(preds, p.label_ids)} - - if model_name.startswith("bert"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "labels", - ["batch", num_labels], - ), - ], - "outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])], - } - else: - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "labels", - ["batch", num_labels], - ), - ], - "outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])], - } - - # Initialize the ORTTrainer within ORTTransformerTrainer - trainer = ORTTransformerTrainer( - model=model, - model_desc=model_desc, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - compute_metrics=compute_metrics, - ) - - # Training - if training_args.do_train: - trainer.train() - trainer.save_model() - - # Evaluation - results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: - logger.info("*** Evaluate ***") - - result = trainer.evaluate() - - logger.info(f"***** Eval results {data_args.task_name} *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - - results.update(result) - - return results - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py b/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py deleted file mode 100644 index 71e6bb8e4d2f2..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py +++ /dev/null @@ -1,6 +0,0 @@ -from orttraining_test_layer_norm_transform import layer_norm_transform # noqa: F401 -from orttraining_test_model_transform import add_expand_shape, add_name, fix_transpose # noqa: F401 - - -def postprocess_model(model): - add_name(model) diff --git a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py b/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py deleted file mode 100644 index 21372caaf6779..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# orttraining_test_checkpoint_storage.py - -import os -import pickle -import shutil - -import numpy as np -import pytest -import torch - -from onnxruntime.training import _checkpoint_storage - -# Helper functions - - -def _equals(a, b): - """Checks recursively if two dictionaries are equal""" - if isinstance(a, dict): - return all(not (key not in b or not _equals(a[key], b[key])) for key in a) - else: - if isinstance(a, bytes): - a = a.decode() - if isinstance(b, bytes): - b = b.decode() - are_equal = a == b - return are_equal if isinstance(are_equal, bool) else are_equal.all() - - return False - - -def _numpy_types(obj_value): - """Return a bool indicating whether or not the input obj_value is a numpy type object - - Recursively checks if the obj_value (could be a dictionary) is a numpy type object. - Exceptions are str and bytes. - - Returns true if object is numpy type, str, or bytes - False if any other type - """ - if not isinstance(obj_value, dict): - return isinstance(obj_value, (str, bytes)) or type(obj_value).__module__ == np.__name__ - - return all(_numpy_types(value) for _, value in obj_value.items()) - - -def _get_dict(separated_key): - """Create dummy dictionary with different datatypes - - Returns the tuple of the entire dummy dictionary created, key argument as a dictionary for _checkpoint_storage.load - function and the value for that key in the original dictionary - - For example the complete dictionary is represented by: - { - 'int1':1, - 'int2': 2, - 'int_list': [1,2,3,5,6], - 'dict1': { - 'np_array': np.arange(100), - 'dict2': {'int3': 3, 'int4': 4}, - 'str1': "onnxruntime" - }, - 'bool1': bool(True), - 'int5': 5, - 'float1': 2.345, - 'np_array_float': np.array([1.234, 2.345, 3.456]), - 'np_array_float_3_dim': np.array([[[1,2],[3,4]], [[5,6],[7,8]]]) - } - - if the input key is ['dict1', 'str1'], then the key argument returned is 'dict1/str1' - and the value corresponding to that is "onnxruntime" - - so, for the above example, the returned tuple is: - (original_dict, {'key': 'dict1/str1', "onnxruntime") - """ - test_dict = { - "int1": 1, - "int2": 2, - "int_list": [1, 2, 3, 5, 6], - "dict1": {"np_array": np.arange(100), "dict2": {"int3": 3, "int4": 4}, "str1": "onnxruntime"}, - "bool1": True, - "int5": 5, - "float1": 2.345, - "np_array_float": np.array([1.234, 2.345, 3.456]), - "np_array_float_3_dim": np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), - } - key = "" - expected_val = test_dict - for single_key in separated_key: - key += single_key + "/" - expected_val = expected_val[single_key] - return test_dict, {"key": key} if len(separated_key) > 0 else dict(), expected_val - - -class _CustomClass: - """Custom object that encpsulates dummy values for loss, epoch and train_step""" - - def __init__(self): - self._loss = 1.23 - self._epoch = 12000 - self._train_step = 25 - - def __eq__(self, other): - if isinstance(other, _CustomClass): - return self._loss == other._loss and self._epoch == other._epoch and self._train_step == other._train_step - - -# Test fixtures - - -@pytest.yield_fixture(scope="function") -def checkpoint_storage_test_setup(): - checkpoint_dir = os.path.abspath("checkpoint_dir/") - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir, exist_ok=True) - pytest.checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ortcp") - yield "checkpoint_storage_test_setup" - shutil.rmtree(checkpoint_dir) - - -@pytest.yield_fixture(scope="function") -def checkpoint_storage_test_parameterized_setup(request, checkpoint_storage_test_setup): - yield request.param - - -# Tests - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - _get_dict([]), - _get_dict(["int1"]), - _get_dict(["dict1"]), - _get_dict(["dict1", "dict2"]), - _get_dict(["dict1", "dict2", "int4"]), - _get_dict(["dict1", "str1"]), - _get_dict(["bool1"]), - _get_dict(["float1"]), - _get_dict(["np_array_float"]), - ], - indirect=True, -) -def test_checkpoint_storage_saved_dict_matches_loaded(checkpoint_storage_test_parameterized_setup): - to_save = checkpoint_storage_test_parameterized_setup[0] - key_arg = checkpoint_storage_test_parameterized_setup[1] - expected = checkpoint_storage_test_parameterized_setup[2] - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - loaded = _checkpoint_storage.load(pytest.checkpoint_path, **key_arg) - assert _equals(loaded, expected) - assert _numpy_types(loaded) - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [{"int_set": {1, 2, 3, 4, 5}}, {"str_set": {"one", "two"}}, [1, 2, 3], 2.352], - indirect=True, -) -def test_checkpoint_storage_saving_non_supported_types_fails(checkpoint_storage_test_parameterized_setup): - to_save = checkpoint_storage_test_parameterized_setup - with pytest.raises(Exception): # noqa: B017 - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - ({"int64_tensor": torch.tensor(np.arange(100))}, "int64_tensor", torch.int64, np.int64), - ({"int32_tensor": torch.tensor(np.arange(100), dtype=torch.int32)}, "int32_tensor", torch.int32, np.int32), - ({"int16_tensor": torch.tensor(np.arange(100), dtype=torch.int16)}, "int16_tensor", torch.int16, np.int16), - ({"int8_tensor": torch.tensor(np.arange(100), dtype=torch.int8)}, "int8_tensor", torch.int8, np.int8), - ({"float64_tensor": torch.tensor(np.array([1.0, 2.0]))}, "float64_tensor", torch.float64, np.float64), - ( - {"float32_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32)}, - "float32_tensor", - torch.float32, - np.float32, - ), - ( - {"float16_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float16)}, - "float16_tensor", - torch.float16, - np.float16, - ), - ], - indirect=True, -) -def test_checkpoint_storage_saving_tensor_datatype(checkpoint_storage_test_parameterized_setup): - tensor_dict = checkpoint_storage_test_parameterized_setup[0] - tensor_name = checkpoint_storage_test_parameterized_setup[1] - tensor_dtype = checkpoint_storage_test_parameterized_setup[2] - np_dtype = checkpoint_storage_test_parameterized_setup[3] - - _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert isinstance(loaded[tensor_name], np.ndarray) - assert tensor_dict[tensor_name].dtype == tensor_dtype - assert loaded[tensor_name].dtype == np_dtype - assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all() - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - ({"two_dim": torch.ones([2, 4], dtype=torch.float64)}, "two_dim"), - ({"three_dim": torch.ones([2, 4, 6], dtype=torch.float64)}, "three_dim"), - ({"four_dim": torch.ones([2, 4, 6, 8], dtype=torch.float64)}, "four_dim"), - ], - indirect=True, -) -def test_checkpoint_storage_saving_multiple_dimension_tensors(checkpoint_storage_test_parameterized_setup): - tensor_dict = checkpoint_storage_test_parameterized_setup[0] - tensor_name = checkpoint_storage_test_parameterized_setup[1] - - _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert isinstance(loaded[tensor_name], np.ndarray) - assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all() - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", [{}, {"a": {}}, {"a": {"b": {}}}], indirect=True -) -def test_checkpoint_storage_saving_and_loading_empty_dictionaries_succeeds(checkpoint_storage_test_parameterized_setup): - saved = checkpoint_storage_test_parameterized_setup - _checkpoint_storage.save(saved, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert _equals(saved, loaded) - - -def test_checkpoint_storage_load_file_that_does_not_exist_fails(checkpoint_storage_test_setup): - with pytest.raises(Exception): # noqa: B017 - _checkpoint_storage.load(pytest.checkpoint_path) - - -def test_checkpoint_storage_for_custom_user_dict_succeeds(checkpoint_storage_test_setup): - custom_class = _CustomClass() - user_dict = {"tensor1": torch.tensor(np.arange(100), dtype=torch.float32), "custom_class": custom_class} - - pickled_bytes = pickle.dumps(user_dict).hex() - to_save = {"a": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32), "user_dict": pickled_bytes} - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - - loaded_dict = _checkpoint_storage.load(pytest.checkpoint_path) - assert (loaded_dict["a"] == to_save["a"].numpy()).all() - try: # noqa: SIM105 - loaded_dict["user_dict"] = loaded_dict["user_dict"].decode() - except AttributeError: - pass - loaded_obj = pickle.loads(bytes.fromhex(loaded_dict["user_dict"])) - - assert torch.all(loaded_obj["tensor1"].eq(user_dict["tensor1"])) - assert loaded_obj["custom_class"] == custom_class diff --git a/orttraining/orttraining/test/python/orttraining_test_data_loader.py b/orttraining/orttraining/test/python/orttraining_test_data_loader.py index aa15b44ae0d66..0009d2d3d7e1b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_data_loader.py +++ b/orttraining/orttraining/test/python/orttraining_test_data_loader.py @@ -4,8 +4,6 @@ import torch from torch.utils.data import DataLoader, Dataset -from onnxruntime.capi.ort_trainer import generate_sample - global_rng = random.Random() @@ -41,6 +39,16 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() +def generate_sample(desc, device=None): + """Generate a sample based on the description""" + # symbolic dimensions are described with strings. set symbolic dimensions to be 1 + size = [s if isinstance(s, (int)) else 1 for s in desc.shape_] + if desc.num_classes_: + return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device) + else: + return torch.randn(size, dtype=desc.dtype_).to(device) + + class OrtTestDataset(Dataset): def __init__(self, input_desc, seq_len, dataset_len, device): import copy diff --git a/orttraining/orttraining/test/python/orttraining_test_debuggability.py b/orttraining/orttraining/test/python/orttraining_test_debuggability.py deleted file mode 100644 index 499f0ba7a1ff5..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_debuggability.py +++ /dev/null @@ -1,40 +0,0 @@ -import pytest -import torch -from _test_commons import _load_pytorch_transformer_model - -from onnxruntime import set_seed -from onnxruntime.training import optim, orttrainer - -############################################################################### -# Testing starts here ######################################################### -############################################################################### - - -@pytest.mark.parametrize( - "seed, device", - [ - (24, "cuda"), - ], -) -def testORTTransformerModelExport(seed, device): - # Common setup - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": { - "check_model_export": True, - }, - "device": { - "id": device, - }, - } - ) - - # Setup for the first ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device) - first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - _ = first_trainer.train_step(data, targets) - assert first_trainer._onnx_model is not None diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index 88d9c00984d3e..2a7012787be6e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -19,6 +19,7 @@ class TestTorchDynamoOrt(unittest.TestCase): def setUp(self): # Make computation deterministic. torch.manual_seed(42) + print(f"TestTorchDynamoOrt uses PyTorch version {torch.__version__}") def test_elementwise_model(self): torch._dynamo.reset() diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis.py index 506aafbe9f618..a3e666dd404f2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis.py @@ -27,7 +27,7 @@ def run_training_apis_python_api_tests(cwd, log): log.debug("Running: ort training api tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_python_bindings.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_py_bindings.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() @@ -37,7 +37,7 @@ def run_onnxblock_tests(cwd, log): log.debug("Running: onnxblock tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnxblock.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_onnxblock.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() diff --git a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py similarity index 97% rename from orttraining/orttraining/test/python/orttraining_test_onnxblock.py rename to orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index f7a7220dd66ea..6e5d54cbb9427 100644 --- a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -17,6 +17,14 @@ # PyTorch Module definitions +def get_opsets_model(filename): + if isinstance(filename, onnx.ModelProto): + onx = filename + else: + onx = onnx.load(filename) + return {d.domain: d.version for d in onx.opset_import} + + class SimpleNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super().__init__() @@ -999,3 +1007,13 @@ def test_save_ort_format(): assert os.path.exists(os.path.join(temp_dir, "eval_model.ort")) assert os.path.exists(os.path.join(temp_dir, "optimizer_model.onnx")) assert os.path.exists(os.path.join(temp_dir, "optimizer_model.ort")) + base_opsets = get_opsets_model(base_model) + training_opsets = get_opsets_model(os.path.join(temp_dir, "training_model.onnx")) + eval_opsets = get_opsets_model(os.path.join(temp_dir, "eval_model.onnx")) + optimizer_opsets = get_opsets_model(os.path.join(temp_dir, "optimizer_model.onnx")) + if base_opsets[""] != training_opsets[""]: + raise AssertionError(f"Opsets mismatch {base_opsets['']} != {training_opsets['']}.") + if base_opsets[""] != eval_opsets[""]: + raise AssertionError(f"Opsets mismatch {base_opsets['']} != {eval_opsets['']}.") + if base_opsets[""] != optimizer_opsets[""]: + raise AssertionError(f"Opsets mismatch {base_opsets['']} != {optimizer_opsets['']}.") diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py similarity index 99% rename from orttraining/orttraining/test/python/orttraining_test_python_bindings.py rename to orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py index d5c37b3e36ee7..34d8c24ccfab4 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py @@ -11,7 +11,7 @@ import onnx import pytest import torch -from orttraining_test_onnxblock import _get_models +from orttraining_test_ort_apis_onnxblock import _get_models import onnxruntime.training.onnxblock as onnxblock from onnxruntime import OrtValue, SessionOptions diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 13024b81f4b3c..ad0e5d8beba3d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5761,6 +5761,7 @@ def run_step(model, input, positions): ("MatMul", 1), ("Dropout", 0), ("LayerNormalization", 0), + ("LayerNormalization", 1), ("Cast", 0), ("BiasGelu", 0), ("Gelu", 0), @@ -5773,12 +5774,18 @@ def test_ops_for_padding_elimination(test_cases): test_op = test_cases[0] case = test_cases[1] + vocab_size, hidden_size = 50265, 768 + batch_size, max_seq_length = 8, 128 + class ToyModel(torch.nn.Module): def __init__(self, vocab_size, hidden_size, pad_token_id): super().__init__() self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) if test_op == "LayerNormalization": - self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05) + if case == 0: + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05) + else: + self.LayerNorm = nn.LayerNorm([max_seq_length, hidden_size], eps=1e-05) self.hidden_size = hidden_size # test test_elementwise op for padding elimination @@ -5889,8 +5896,6 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): batched_inputs.append(torch.cat((input_id, padding))) return torch.stack(batched_inputs) - vocab_size, hidden_size = 50265, 768 - batch_size, max_seq_length = 8, 128 device = "cuda" model = ORTModule(ToyModel(vocab_size, hidden_size, 1).to(device)) x = generate_inputs(batch_size, max_seq_length, vocab_size) @@ -5908,7 +5913,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3 else: assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2 - gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten") + recover_pad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten") def find_input_node_type(model, arg): result = [] @@ -5917,14 +5922,14 @@ def find_input_node_type(model, arg): result.append(node) return result[0].op_type if len(result) == 1 else None - gathergrad_input_optypes = [find_input_node_type(training_model, arg) for arg in gathergrad_node.input] + recover_pad_input_optypes = [find_input_node_type(training_model, arg) for arg in recover_pad_node.input] if test_op == "Add" or test_op == "Mul" or test_op == "Sub": - assert test_op in gathergrad_input_optypes + assert test_op in recover_pad_input_optypes else: if case == 0: - assert test_op in gathergrad_input_optypes + assert test_op in recover_pad_input_optypes else: - assert "ATen" in gathergrad_input_optypes + assert "ATen" in recover_pad_input_optypes del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] diff --git a/orttraining/orttraining/test/python/orttraining_test_hooks.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_hooks.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_hooks.py rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_hooks.py diff --git a/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py index d205e8f2377dd..0c381d70ca4c1 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py @@ -5,6 +5,7 @@ import json import os import random +import uuid import _test_helpers import onnx @@ -200,7 +201,8 @@ def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwar _, op_type = op_type.split("::") pt_outputs = TorchFuncExecutor.run(op_type, *pt_inputs, **kwargs) model_str = create_model_func(op_type, onnx_dtype, **kwargs).SerializeToString() - ort_outputs = call_triton_by_onnx(hash(model_str), model_str, *[to_dlpack(tensor) for tensor in ort_inputs]) + unique_id = uuid.uuid1().int >> 64 + ort_outputs = call_triton_by_onnx(unique_id, model_str, *[to_dlpack(tensor) for tensor in ort_inputs]) if isinstance(pt_outputs, tuple): assert isinstance(ort_outputs, tuple) assert len(pt_outputs) == len(ort_outputs) @@ -229,9 +231,9 @@ def _run_module_test(module_cls, dtype, gen_inputs_func, triton_op_count, **kwar ort_output = _run_step(ort_model, *ort_inputs) _test_helpers.assert_values_are_close(pt_output, ort_output, rtol=rtol, atol=atol) _test_helpers.assert_gradients_match_and_reset_gradient(pt_model, ort_model, rtol=rtol, atol=atol) - for i in range(len(pt_inputs)): - if pt_inputs[i].requires_grad: - _test_helpers.assert_values_are_close(pt_inputs[i].grad, ort_inputs[i].grad, rtol=rtol, atol=atol) + for idx, pt_input in enumerate(pt_inputs): + if pt_input.requires_grad: + _test_helpers.assert_values_are_close(pt_input.grad, ort_inputs[idx].grad, rtol=rtol, atol=atol) assert os.path.exists(os.path.join(os.getcwd(), "triton_model_torch_exported_training.onnx")) assert os.path.exists(os.path.join(os.getcwd(), "triton_model_optimized_training.onnx")) @@ -250,12 +252,12 @@ def _run_module_test(module_cls, dtype, gen_inputs_func, triton_op_count, **kwar def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_count, **kwargs): + os.environ["ORTMODULE_ENABLE_TUNING"] = "1" + os.environ["ORTMODULE_TUNING_RESULTS_PATH"] = "./" pt_model = module_cls().to(DEVICE).to(dtype) ort_model = ORTModule(copy.deepcopy(pt_model)) rtol = kwargs.get("rtol", 1e-03 if dtype == torch.float16 else 1e-04) atol = kwargs.get("atol", 1e-03 if dtype == torch.float16 else 1e-05) - os.environ["ORTMODULE_ENABLE_TUNING"] = "1" - os.environ["ORTMODULE_TUNING_RESULTS_PATH"] = "./" for _ in range(5): pt_inputs = gen_inputs_func(dtype) ort_inputs = copy.deepcopy(pt_inputs) @@ -265,7 +267,7 @@ def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_co _test_helpers.assert_gradients_match_and_reset_gradient(pt_model, ort_model, rtol=rtol, atol=atol) tunable_results_file = os.path.join(os.getcwd(), "tuning_results_training.json") assert os.path.exists(tunable_results_file) - with open(tunable_results_file) as f: + with open(tunable_results_file, encoding="UTF-8") as f: tunable_results = json.load(f) assert tunable_op in str(tunable_results) del os.environ["ORTMODULE_ENABLE_TUNING"] @@ -275,7 +277,7 @@ def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_co if tunable_op in k: for param, impl in v.items(): v[param] = (impl + 1 + i) % impl_count - with open(tunable_results_file, "w") as f: + with open(tunable_results_file, "w", encoding="UTF-8") as f: json.dump(new_tunable_results, f) ort_model = ORTModule(copy.deepcopy(pt_model)) for _ in range(5): @@ -541,6 +543,8 @@ def _gen_inputs(dtype): ([123, 4, 5, 6], [2], False), ([16, 8, 16, 8], [1, 3], True), ([16, 8, 16, 8], [0, 2], False), + ([16, 8, 16, 8], [0, 1, 2, 3], True), + ([16, 1, 16, 8], [0, 1, 2, 3], False), ], ) def test_reduce_op(op_type, onnx_dtype, input_shape_and_reduce_info): @@ -781,6 +785,43 @@ def _gen_inputs(dtype): _run_module_test(NeuralNetLayerNorm, dtype, _gen_inputs, 2) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_dynamic_shapes_elementwise_module(dtype): + class NeuralNetSymbolicShapesElementwise(torch.nn.Module): + def forward(self, x, y, u, v): + return x * y - (u + v) + + def _gen_inputs(dtype): + dim1 = 64 * random.randint(2, 4) + dim2 = 64 * random.randint(2, 4) + return [ + torch.rand(16, dim1, dim2, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(16, 1, dim2, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(dim1, 1, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(16, dim1, dim2, dtype=dtype, device=DEVICE, requires_grad=True), + ] + + _run_module_test(NeuralNetSymbolicShapesElementwise, dtype, _gen_inputs, 1) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_dynamic_shapes_reduction_module(dtype): + class NeuralNetSymbolicShapesReduction(torch.nn.Module): + def forward(self, x, y, z): + return torch.softmax(x * y + z, dim=-1) + + def _gen_inputs(dtype): + dim1 = 64 * random.randint(2, 4) + dim2 = 64 * random.randint(2, 4) + return [ + torch.rand(16, dim1, dim2, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(16, 1, dim2, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(dim1, 1, dtype=dtype, device=DEVICE, requires_grad=True), + ] + + _run_module_test(NeuralNetSymbolicShapesReduction, dtype, _gen_inputs, 2) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) @pytest.mark.parametrize("has_sum", [True, False]) def test_slice_scel_module(dtype, has_sum): @@ -832,3 +873,34 @@ def _gen_inputs(dtype): return [torch.rand(m_n_k[0], m_n_k[2], dtype=dtype, device=DEVICE, requires_grad=True)] _run_tunable_op_test(NeuralNetGemm, dtype, _gen_inputs, "GemmTunableOp", 2) + + +def test_user_config(): + n, d, h, w = 8, 768, 12, 64 + dtype = torch.float32 + + class NeuralNetElementwise(torch.nn.Module): + def forward(self, input1, input2, input3, input4): + return input1 + input2 - input3 * input4 + + def _gen_inputs(dtype): + return [ + torch.rand(n, d, h, w, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(w, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(d, 1, 1, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(n, 1, h, w, dtype=dtype, device=DEVICE, requires_grad=True), + ] + + user_config = ( + '{"ops": {"Add": {"versions": [13, 14]}, "Mul": {"versions": [13, 14]}}, ' + '"initializer": "scalar", "min_nodes": 2}' + ) + with open("user_config.json", "w", encoding="UTF-8") as f: + f.write(user_config) + os.environ["ORTMODULE_TRITON_CONFIG_FILE"] = "./user_config.json" + + # Mul is not supported, the graph is splited to 2 subgraphs with single Op, which will not be fused to TritonOp. + _run_module_test(NeuralNetElementwise, dtype, _gen_inputs, 0) + + del os.environ["ORTMODULE_TRITON_CONFIG_FILE"] + os.remove(os.path.join(os.getcwd(), "user_config.json")) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py deleted file mode 100644 index 45b87b32f7d64..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py +++ /dev/null @@ -1,1283 +0,0 @@ -import copy # noqa: F401 -import inspect # noqa: F401 -import math # noqa: F401 -import os -from functools import partial - -import _test_commons -import _test_helpers -import onnx -import pytest -import torch -from numpy.testing import assert_allclose - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription -from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler -from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import amp, optim, orttrainer - -############################################################################### -# Helper functions ############################################################ -############################################################################### - - -def generate_random_input_from_model_desc(desc, seed=1, device="cuda:0"): - """Generates a sample input for the BERT model using the model desc""" - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - dtype = torch.int64 - vocab_size = 30528 - num_classes = [vocab_size, 2, 2, vocab_size, 2] - dims = {"batch_size": 16, "seq_len": 1} - sample_input = [] - for index, input in enumerate(desc["inputs"]): - size = [] - for s in input[1]: - if isinstance(s, (int)): - size.append(s) - else: - size.append(dims[s] if s in dims else 1) - sample_input.append(torch.randint(0, num_classes[index], tuple(size), dtype=dtype).to(device)) - return sample_input - - -# EXPERIMENTAL HELPER FUNCTIONS - - -def bert_model_description(dynamic_shape=True): - """Creates the model description dictionary with static dimensions""" - - if dynamic_shape: - model_desc = { - "inputs": [ - ("input_ids", ["batch_size", "seq_len"]), - ( - "segment_ids", - ["batch_size", "seq_len"], - ), - ( - "input_mask", - ["batch_size", "seq_len"], - ), - ( - "masked_lm_labels", - ["batch_size", "seq_len"], - ), - ( - "next_sentence_labels", - [ - "batch_size", - ], - ), - ], - "outputs": [("loss", [], True)], - } - else: - batch_size = 16 - seq_len = 1 - model_desc = { - "inputs": [ - ("input_ids", [batch_size, seq_len]), - ( - "segment_ids", - [batch_size, seq_len], - ), - ( - "input_mask", - [batch_size, seq_len], - ), - ( - "masked_lm_labels", - [batch_size, seq_len], - ), - ( - "next_sentence_labels", - [ - batch_size, - ], - ), - ], - "outputs": [("loss", [], True)], - } - return model_desc - - -def optimizer_parameters(model): - """A method to assign different hyper parameters for different model parameter groups""" - - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay_param_group = [] - for initializer in model.graph.initializer: - if any(key in initializer.name for key in no_decay_keys): - no_decay_param_group.append(initializer.name) - params = [ - { - "params": no_decay_param_group, - "alpha": 0.9, - "beta": 0.999, - "lambda_coef": 0.0, - "epsilon": 1e-6, - "do_bias_correction": False, - } - ] - - return params - - -def load_bert_onnx_model(): - bert_onnx_model_path = os.path.join("testdata", "bert_toy_postprocessed.onnx") - model = onnx.load(bert_onnx_model_path) - return model - - -class CustomLossScaler(amp.LossScaler): - def __init__(self, loss_scale=float(1 << 16)): - super().__init__(loss_scale) - self._initial_loss_scale = loss_scale - self.loss_scale = loss_scale - - def reset(self): - self.loss_scale = self._initial_loss_scale - - def update(self, train_step_info): - self.loss_scale *= 0.9 - return self.loss_scale - - -# LEGACY HELPER FUNCTIONS - - -class LegacyCustomLossScaler: - def __init__(self, loss_scale=float(1 << 16)): - self._initial_loss_scale = loss_scale - self.loss_scale_ = loss_scale - - def reset(self): - self.loss_scale_ = self._initial_loss_scale - - def update_loss_scale(self, is_all_finite): - self.loss_scale_ *= 0.9 - - -def legacy_model_params(lr, device=torch.device("cuda", 0)): # noqa: B008 - legacy_model_desc = legacy_bert_model_description() - learning_rate_description = legacy_ort_trainer_learning_rate_description() - learning_rate = torch.tensor([lr]).to(device) - return (legacy_model_desc, learning_rate_description, learning_rate) - - -def legacy_ort_trainer_learning_rate_description(): - return Legacy_IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - - -def legacy_bert_model_description(): - input_ids_desc = Legacy_IODescription("input_ids", ["batch", "max_seq_len_in_batch"]) - segment_ids_desc = Legacy_IODescription("segment_ids", ["batch", "max_seq_len_in_batch"]) - input_mask_desc = Legacy_IODescription("input_mask", ["batch", "max_seq_len_in_batch"]) - masked_lm_labels_desc = Legacy_IODescription("masked_lm_labels", ["batch", "max_seq_len_in_batch"]) - next_sentence_labels_desc = Legacy_IODescription( - "next_sentence_labels", - [ - "batch", - ], - ) - loss_desc = Legacy_IODescription("loss", []) - - return Legacy_ModelDescription( - [input_ids_desc, segment_ids_desc, input_mask_desc, masked_lm_labels_desc, next_sentence_labels_desc], - [loss_desc], - ) - - -def legacy_optim_params_a(name): - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -def legacy_optim_params_b(name): - params = ["bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.weight"] - if name in params: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6, "do_bias_correction": False} - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -def legacy_optim_params_c(name): - params_group = optimizer_parameters(load_bert_onnx_model()) - if name in params_group[0]["params"]: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6, "do_bias_correction": False} - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -############################################################################### -# Testing starts here ######################################################### -############################################################################### - - -@pytest.mark.parametrize("dynamic_shape", [(True), (False)]) -def testToyBERTModelBasicTraining(dynamic_shape): - model_desc = bert_model_description(dynamic_shape) - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({}) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for _i in range(10): - sample_input = generate_random_input_from_model_desc(model_desc) - output = trainer.train_step(*sample_input) - assert output.shape == torch.Size([]) - - -@pytest.mark.parametrize( - "expected_losses", - [([11.041123, 10.986166, 11.101636, 11.013366, 11.03775, 11.041175, 10.957118, 11.069563, 11.040824, 11.16437])], -) -def testToyBERTDeterministicCheck(expected_losses): - # Common setup - train_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optimizer_parameters(model) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - experimental_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "initial_lr, lr_scheduler, expected_learning_rates, expected_losses", - [ - ( - 1.0, - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [ - 10.988012313842773, - 10.99213981628418, - 120.79301452636719, - 36.11647033691406, - 95.83200073242188, - 221.2766571044922, - 208.40316772460938, - 279.5332946777344, - 402.46380615234375, - 325.79254150390625, - ], - ), - ( - 0.5, - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], - [ - 10.988012313842773, - 10.99213981628418, - 52.69743347167969, - 19.741533279418945, - 83.88340759277344, - 126.39848327636719, - 91.53898620605469, - 63.62016296386719, - 102.21206665039062, - 180.1424560546875, - ], - ), - ( - 1.0, - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.9931806517013612, - 0.9397368756032445, - 0.8386407858128706, - 0.7008477123264848, - 0.5412896727361662, - 0.37725725642960045, - 0.22652592093878665, - 0.10542974530180327, - 0.02709137914968268, - ], - [ - 10.988012313842773, - 10.99213981628418, - 120.6441650390625, - 32.152557373046875, - 89.63705444335938, - 138.8782196044922, - 117.57748413085938, - 148.01927185058594, - 229.60403442382812, - 110.2930908203125, - ], - ), - ( - 1.0, - optim.lr_scheduler.LinearWarmupLRScheduler, - [ - 0.0, - 0.9473684210526315, - 0.8421052631578947, - 0.7368421052631579, - 0.631578947368421, - 0.5263157894736842, - 0.42105263157894735, - 0.3157894736842105, - 0.21052631578947367, - 0.10526315789473684, - ], - [ - 10.988012313842773, - 10.99213981628418, - 112.89633178710938, - 31.114538192749023, - 80.94029235839844, - 131.34490966796875, - 111.4329605102539, - 133.74252319335938, - 219.37344360351562, - 109.67041015625, - ], - ), - ( - 1.0, - optim.lr_scheduler.PolyWarmupLRScheduler, - [ - 0.0, - 0.9473684263157895, - 0.8421052789473684, - 0.7368421315789474, - 0.6315789842105263, - 0.5263158368421054, - 0.42105268947368424, - 0.31578954210526317, - 0.21052639473684212, - 0.10526324736842106, - ], - [ - 10.988012313842773, - 10.99213981628418, - 112.89633178710938, - 31.114538192749023, - 80.9402847290039, - 131.3447265625, - 111.43253326416016, - 133.7415008544922, - 219.37147521972656, - 109.66986083984375, - ], - ), - ], -) -def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates, expected_losses): - return # TODO: re-enable after nondeterminism on backend is fixed - # Common setup - device = "cuda" - total_steps = 10 - seed = 1 - warmup = 0.05 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Setup LR Schedulers - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig(lr=initial_lr) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "lr_scheduler": lr_scheduler, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - learning_rates = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0]) - - # Check output - _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=rtol) - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "loss_scaler, expected_losses", - [ - ( - None, - [ - 11.041126, - 10.986309, - 11.101673, - 11.013394, - 11.037781, - 11.041253, - 10.957072, - 11.069506, - 11.040807, - 11.164349, - ], - ), - ( - amp.DynamicLossScaler(), - [ - 11.041126, - 10.986309, - 11.101673, - 11.013394, - 11.037781, - 11.041253, - 10.957072, - 11.069506, - 11.040807, - 11.164349, - ], - ), - ( - CustomLossScaler(), - [ - 11.041126, - 10.986309, - 11.101645, - 11.013412, - 11.037757, - 11.041273, - 10.957077, - 11.069525, - 11.040765, - 11.164298, - ], - ), - ], -) -def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses): - # Common setup - total_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "gradient_accumulation_steps, expected_losses", - [ - ( - 1, - [ - 11.041123, - 10.986166, - 11.101636, - 11.013366, - 11.03775, - 11.041175, - 10.957118, - 11.069563, - 11.040824, - 11.16437, - ], - ), - ( - 4, - [ - 11.041123, - 10.982856, - 11.105512, - 11.006721, - 11.03358, - 11.05058, - 10.955864, - 11.059035, - 11.037753, - 11.162649, - ], - ), - ( - 7, - [ - 11.041123, - 10.982856, - 11.105512, - 11.006721, - 11.036314, - 11.055109, - 10.960751, - 11.05809, - 11.038856, - 11.159635, - ], - ), - ], -) -def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses): - # Common setup - total_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -def testToyBertCheckpointBasic(): - # Common setup - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({"debug": {"deterministic_compute": True}}) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - sd = trainer.state_dict() - - ## All initializers must be present in the state_dict - ## when the specified model for ORTTRainer is an ONNX model - for param in trainer._onnx_model.graph.initializer: - assert param.name in sd["model"]["full_precision"] - - ## Modify one of the state values and load into ORTTrainer - sd["model"]["full_precision"]["bert.encoder.layer.0.attention.output.LayerNorm.weight"] += 10 - trainer.load_state_dict(sd) - - ## Save a checkpoint - ckpt_dir = "testdata" - trainer.save_checkpoint(os.path.join(ckpt_dir, "bert_toy_save_test.ortcp")) - del trainer - del model - - # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer - model2 = load_bert_onnx_model() - model_desc2 = bert_model_description() - trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts) - trainer2.load_checkpoint(os.path.join(ckpt_dir, "bert_toy_save_test.ortcp")) - loaded_sd = trainer2.state_dict() - - # Assert whether original state and the one loaded from checkpoint matches - _test_commons.assert_all_states_close_ort(sd, loaded_sd) - - -def testToyBertCheckpointFrozenWeights(): - # Common setup - seed = 1 - total_steps = 10 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "utils": {"frozen_weights": ["bert.encoder.layer.0.attention.self.value.weight"]}, - } - ) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - optim_config = optim.LambConfig() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train for a few steps - for _i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, seed) - _ = trainer.train_step(*sample_input) - sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1) - # Evaluate once to get a base loss - loss = trainer.eval_step(*sample_input) - # Save checkpoint - state_dict = trainer.state_dict() - - # Load previous state into another instance of ORTTrainer - model2 = load_bert_onnx_model() - model_desc2 = bert_model_description() - optim_config2 = optim.LambConfig() - trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts) - trainer2.load_state_dict(state_dict) - # Evaluate once to get a base loss - ckpt_loss = trainer2.eval_step(*sample_input) - - # Must match as both trainers have the same dict state - assert_allclose(loss.cpu(), ckpt_loss.cpu()) - loaded_state_dict = trainer2.state_dict() - _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict) - - -@pytest.mark.parametrize( - "optimizer, mixedprecision_enabled", - [ - (optim.LambConfig(), False), - (optim.AdamConfig(), False), - (optim.LambConfig(), True), - (optim.AdamConfig(), True), - ], -) -def testToyBertLoadOptimState(optimizer, mixedprecision_enabled): - # Common setup - device = "cuda" - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optimizer - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": {"id": device}, - "mixed_precision": { - "enabled": mixedprecision_enabled, - }, - "distributed": {"allreduce_post_accumulation": True}, - } - ) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - dummy_init_state = _test_commons.generate_dummy_optim_state(model, optimizer) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - trainer.load_state_dict(dummy_init_state) - - # Expected values - input_ids = torch.tensor( - [ - [26598], - [21379], - [19922], - [5219], - [5644], - [20559], - [23777], - [25672], - [22969], - [16824], - [16822], - [635], - [27399], - [20647], - [18519], - [15546], - ], - device=device, - ) - segment_ids = torch.tensor( - [[0], [1], [0], [1], [0], [0], [1], [0], [0], [1], [1], [0], [0], [1], [1], [1]], device=device - ) - input_mask = torch.tensor( - [[0], [0], [0], [0], [1], [1], [1], [0], [1], [1], [0], [0], [0], [1], [0], [0]], device=device - ) - masked_lm_labels = torch.tensor( - [ - [25496], - [16184], - [11005], - [16228], - [14884], - [21660], - [8678], - [23083], - [4027], - [8397], - [11921], - [1333], - [26482], - [1666], - [17925], - [27978], - ], - device=device, - ) - next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device) - - # Actual values - _ = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels) - - actual_state_dict = trainer.state_dict() - del actual_state_dict["model"] - _test_commons.assert_all_states_close_ort(actual_state_dict, dummy_init_state) - - -@pytest.mark.parametrize( - "model_params", - [ - (["bert.embeddings.LayerNorm.bias"]), - ( - [ - "bert.embeddings.LayerNorm.bias", - "bert.embeddings.LayerNorm.weight", - "bert.encoder.layer.0.attention.output.LayerNorm.bias", - ] - ), - ], -) -def testORTTrainerFrozenWeights(model_params): - device = "cuda" - total_steps = 10 - seed = 1 - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - # Setup ORTTrainer WITHOUT frozen weights - opts_dict = { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - opts = orttrainer.ORTTrainerOptions(opts_dict) - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - trainer.train_step(*sample_input) - - # All model_params must be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert all([param in session_state for param in model_params]) - - # Setup ORTTrainer WITH frozen weights - opts_dict.update({"utils": {"frozen_weights": model_params}}) - opts = orttrainer.ORTTrainerOptions(opts_dict) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - trainer.train_step(*sample_input) - - # All model_params CANNOT be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert not any([param in session_state for param in model_params]) - - -def testToyBERTSaveAsONNX(): - device = "cuda" - onnx_file_name = "_____temp_toy_bert_onnx_model.onnx" - if os.path.exists(onnx_file_name): - os.remove(onnx_file_name) - assert not os.path.exists(onnx_file_name) - - # Load trainer - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - trainer.save_as_onnx(onnx_file_name) - assert os.path.exists(onnx_file_name) - - with open(onnx_file_name, "rb") as f: - bin_str = f.read() - reload_onnx_model = onnx.load_model_from_string(bin_str) - os.remove(onnx_file_name) - - # Create a new trainer from persisted ONNX model and compare with original ONNX model - trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config, options=opts) - assert trainer_from_onnx._onnx_model is not None - assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model) - for initializer, loaded_initializer in zip( - trainer._onnx_model.graph.initializer, trainer_from_onnx._onnx_model.graph.initializer - ): - assert initializer.name == loaded_initializer.name - assert onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph( - trainer._onnx_model.graph - ) - _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx) - - -############################################################################### -# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ -############################################################################### -@pytest.mark.parametrize( - "optimizer_config", - [ - (optim.AdamConfig), - # (optim.LambConfig), # TODO: re-enable after nondeterminism on backend is fixed - (optim.SGDConfig), - ], -) -def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config): - # Common setup - train_steps = 512 - - device = "cuda" - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - optim_config = optimizer_config(lr=0.01) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - if optimizer_config == optim.AdamConfig: - legacy_optimizer = "AdamOptimizer" - elif optimizer_config == optim.LambConfig: - legacy_optimizer = "LambOptimizer" - elif optimizer_config == optim.SGDConfig: - legacy_optimizer = "SGDOptimizer" - else: - raise RuntimeError("Invalid optimizer_config") - - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - legacy_optimizer, - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - ) - legacy_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True) - - -@pytest.mark.parametrize( - "initial_lr, lr_scheduler, legacy_lr_scheduler", - [ - (1.0, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (0.5, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (1.0, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler), - (1.0, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler), - (1.0, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler), - ], -) -def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, legacy_lr_scheduler): - ############################################################################ - # These tests require hard-coded values for 'total_steps' and 'initial_lr' # - ############################################################################ - - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - warmup = 0.05 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - - # Setup both Experimental and Legacy LR Schedulers before the experimental loop - if ( - legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler - or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler - ): - legacy_lr_scheduler = partial( - legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup - ) - elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler: - legacy_lr_scheduler = partial( - legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup, cycles=cycles - ) - elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler: - legacy_lr_scheduler = partial( - legacy_lr_scheduler, - initial_lr=initial_lr, - total_steps=total_steps, - warmup=warmup, - power=power, - lr_end=lr_end, - ) - else: - raise RuntimeError("Invalid legacy_lr_scheduler") - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optim.AdamConfig(lr=initial_lr) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "lr_scheduler": lr_scheduler, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(i)) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(initial_lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - get_lr_this_step=legacy_lr_scheduler, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize( - "loss_scaler, legacy_loss_scaler", - [ - (None, Legacy_LossScaler("ort_test_input_loss_scaler", True)), - (amp.DynamicLossScaler(), Legacy_LossScaler("ort_test_input_loss_scaler", True)), - (CustomLossScaler(), LegacyCustomLossScaler()), - ], -) -def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig(lr=0.001) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - use_mixed_precision=True, - loss_scaler=legacy_loss_scaler, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize("gradient_accumulation_steps", [(1), (4), (7)]) -def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation_steps): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - loss = trainer.train_step(*sample_input) - experimental_losses.append(loss.cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - gradient_accumulation_steps=gradient_accumulation_steps, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize( - "params, legacy_optim_map", - [ - # Change the hyper parameters for all parameters - ([], legacy_optim_params_a), - # Change the hyperparameters for a subset of hardcoded parameters - ( - [ - { - "params": ["bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.weight"], - "alpha": 0.9, - "beta": 0.999, - "lambda_coef": 0.0, - "epsilon": 1e-6, - "do_bias_correction": False, - } - ], - legacy_optim_params_b, - ), - # Change the hyperparameters for a generated set of paramers - (optimizer_parameters(load_bert_onnx_model()), legacy_optim_params_c), - ], -) -def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL API - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.AdamConfig( - params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False - ) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr) - - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - legacy_optim_map, - learning_rate_description, - device, - _use_deterministic_compute=True, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - legacy_sample_input = [*sample_input, learning_rate] - legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py deleted file mode 100644 index d366f2cb26557..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py +++ /dev/null @@ -1,722 +0,0 @@ -from unittest.mock import Mock, patch - -import numpy as np -import onnx -import pytest -import torch -from _test_commons import _load_pytorch_transformer_model - -from onnxruntime.training import _checkpoint_storage, amp, checkpoint, optim, orttrainer # noqa: F401 - -# Helper functions - - -def _create_trainer(zero_enabled=False): - """Cerates a simple ORTTrainer for ORTTrainer functional tests""" - - device = "cuda" - optim_config = optim.LambConfig(lr=0.1) - opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} - if zero_enabled: - opts["distributed"] = { - "world_rank": 0, - "world_size": 1, - "horizontal_parallel_size": 1, - "data_parallel_size": 1, - "allreduce_post_accumulation": True, - "deepspeed_zero_optimization": {"stage": 1}, - } - model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer( - model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(opts) - ) - - return trainer - - -class _training_session_mock: # noqa: N801 - """Mock object for the ORTTrainer _training_session member""" - - def __init__(self, model_states, optimizer_states, partition_info): - self.model_states = model_states - self.optimizer_states = optimizer_states - self.partition_info = partition_info - - def get_model_state(self, include_mixed_precision_weights=False): - return self.model_states - - def get_optimizer_state(self): - return self.optimizer_states - - def get_partition_info_map(self): - return self.partition_info - - -def _get_load_state_dict_strict_error_arguments(): - """Return a list of tuples that can be used as parameters for test_load_state_dict_errors_when_model_key_missing - - Construct a list of tuples (training_session_state_dict, input_state_dict, error_arguments) - The load_state_dict function will compare the two state dicts (training_session_state_dict, input_state_dict) and - throw a runtime error with the missing/unexpected keys. The error arguments capture these missing/unexpected keys. - """ - - training_session_state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5)}, - }, - } - - # input state dictionaries - precision_key_missing = {"model": {}, "optimizer": {}} - precision_key_unexpected = {"model": {"full_precision": {}, "mixed_precision": {}}, "optimizer": {}} - model_state_key_missing = {"model": {"full_precision": {}}, "optimizer": {}} - model_state_key_unexpected = {"model": {"full_precision": {"a": 2, "b": 3, "c": 4}}, "optimizer": {}} - optimizer_model_state_key_missing = {"model": {"full_precision": {"a": 2, "b": 3}}, "optimizer": {}} - optimizer_model_state_key_unexpected = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": {"a": {}, "shared_optimizer_state": {}, "b": {}}, - } - optimizer_state_key_missing = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": {"a": {}, "shared_optimizer_state": {"step": np.arange(5)}}, - } - optimizer_state_key_unexpected = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5), "another_step": np.arange(1)}, - }, - } - - input_arguments = [ - (training_session_state_dict, precision_key_missing, ["full_precision"]), - (training_session_state_dict, precision_key_unexpected, ["mixed_precision"]), - (training_session_state_dict, model_state_key_missing, ["a", "b"]), - (training_session_state_dict, model_state_key_unexpected, ["c"]), - (training_session_state_dict, optimizer_model_state_key_missing, ["a", "shared_optimizer_state"]), - (training_session_state_dict, optimizer_model_state_key_unexpected, ["b"]), - (training_session_state_dict, optimizer_state_key_missing, ["Moment_1", "Moment_2"]), - (training_session_state_dict, optimizer_state_key_unexpected, ["another_step"]), - ] - - return input_arguments - - -# Tests - - -def test_empty_state_dict_when_training_session_uninitialized(): - trainer = _create_trainer() - with pytest.warns(UserWarning) as user_warning: - state_dict = trainer.state_dict() - - assert len(state_dict.keys()) == 0 - assert ( - user_warning[0].message.args[0] == "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict()." - ) - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_model_states(onnx_model_mock): - trainer = _create_trainer() - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["model"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_model_states(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_model_states_pytorch_format(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict(pytorch_format=True) - assert torch.all(torch.eq(state_dict["a"], torch.tensor(np.arange(5)))) - assert torch.all(torch.eq(state_dict["b"], torch.tensor(np.arange(7)))) - - -@patch("onnx.ModelProto") -def test_onnx_graph_provides_frozen_model_states(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - trainer.options.utils.frozen_weights = ["a_frozen_weight", "a_float16_weight"] - trainer._onnx_model.graph.initializer = [ - onnx.numpy_helper.from_array(np.array([1, 2, 3], dtype=np.float32), "a_frozen_weight"), - onnx.numpy_helper.from_array(np.array([4, 5, 6], dtype=np.float32), "a_non_fronzen_weight"), - onnx.numpy_helper.from_array(np.array([7, 8, 9], dtype=np.float16), "a_float16_weight"), - ] - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - assert (state_dict["model"]["full_precision"]["a_frozen_weight"] == np.array([1, 2, 3], dtype=np.float32)).all() - assert "a_non_fronzen_weight" not in state_dict["model"]["full_precision"] - assert (state_dict["model"]["full_precision"]["a_float16_weight"] == np.array([7, 8, 9], dtype=np.float32)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_optimizer_states(onnx_model_mock): - trainer = _create_trainer() - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["optimizer"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_optimizer_states(onnx_model_mock): - trainer = _create_trainer() - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - training_session_mock = _training_session_mock({}, optimizer_states, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all() - assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_optimizer_states_pytorch_format(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - training_session_mock = _training_session_mock(model_states, optimizer_states, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict(pytorch_format=True) - assert "optimizer" not in state_dict - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_partition_info_map(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["partition_info"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_partition_info_map(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - partition_info = {"a": {"original_dim": [1, 2, 3]}} - training_session_mock = _training_session_mock({}, {}, partition_info) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3] - - -@patch("onnx.ModelProto") -def test_training_session_provides_all_states(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - partition_info = {"a": {"original_dim": [1, 2, 3]}} - training_session_mock = _training_session_mock(model_states, optimizer_states, partition_info) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all() - assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all() - assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3] - - -def test_load_state_dict_holds_when_training_session_not_initialized(): - trainer = _create_trainer() - state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5)}, - }, - } - assert not trainer._load_state_dict - state_dict = trainer.load_state_dict(state_dict) - assert trainer._load_state_dict - - -@pytest.mark.parametrize( - "state_dict, input_state_dict, error_key", - [ - ( - {"model": {}, "optimizer": {}}, - {"model": {}, "optimizer": {}, "trainer_options": {"optimizer_name": "LambOptimizer"}}, - "train_step_info", - ), - ( - {"optimizer": {}, "train_step_info": {"optimization_step": 0, "step": 0}}, - { - "optimizer": {}, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - "train_step_info": {"optimization_step": 0, "step": 0}, - }, - "model", - ), - ( - {"model": {}, "train_step_info": {"optimization_step": 0, "step": 0}}, - { - "model": {}, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - "train_step_info": {"optimization_step": 0, "step": 0}, - }, - "optimizer", - ), - ], -) -def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, input_state_dict, error_key): - trainer = _create_trainer() - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=state_dict) - trainer._update_onnx_model_initializers = Mock() - trainer._init_session = Mock() - with patch("onnx.ModelProto") as onnx_model_mock: - trainer._onnx_model = onnx_model_mock() - trainer._onnx_model.graph.initializer = [] - with pytest.warns(UserWarning) as user_warning: - trainer.load_state_dict(input_state_dict) - - assert user_warning[0].message.args[0] == f"Missing key: {error_key} in state_dict" - - -@pytest.mark.parametrize("state_dict, input_state_dict, error_keys", _get_load_state_dict_strict_error_arguments()) -def test_load_state_dict_errors_when_state_dict_mismatch(state_dict, input_state_dict, error_keys): - trainer = _create_trainer() - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=state_dict) - with pytest.raises(RuntimeError) as runtime_error: - trainer.load_state_dict(input_state_dict) - - assert any(key in str(runtime_error.value) for key in error_keys) - - -@patch("onnx.ModelProto") -def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_mock): - trainer = _create_trainer() - training_session_state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - }, - } - - input_state_dict = { - "model": {"full_precision": {"a": np.array([1, 2]), "b": np.array([3, 4])}}, - "optimizer": { - "a": {"Moment_1": np.array([5, 6]), "Moment_2": np.array([7, 8])}, - "shared_optimizer_state": {"step": np.array([9])}, - }, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - } - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=training_session_state_dict) - trainer._onnx_model = onnx_model_mock() - trainer._onnx_model.graph.initializer = [ - onnx.numpy_helper.from_array(np.arange(20, dtype=np.float32), "a"), - onnx.numpy_helper.from_array(np.arange(25, dtype=np.float32), "b"), - ] - trainer._update_onnx_model_initializers = Mock() - trainer._init_session = Mock() - - trainer.load_state_dict(input_state_dict) - - loaded_initializers, _ = trainer._update_onnx_model_initializers.call_args - state_dict_to_load, _ = trainer._init_session.call_args - - assert "a" in loaded_initializers[0] - assert (loaded_initializers[0]["a"] == np.array([1, 2])).all() - assert "b" in loaded_initializers[0] - assert (loaded_initializers[0]["b"] == np.array([3, 4])).all() - - assert (state_dict_to_load[0]["a"]["Moment_1"] == np.array([5, 6])).all() - assert (state_dict_to_load[0]["a"]["Moment_2"] == np.array([7, 8])).all() - assert (state_dict_to_load[0]["shared_optimizer_state"]["step"] == np.array([9])).all() - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_calls_checkpoint_storage_save(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc") - - save_args, _ = save_mock.call_args - assert "model" in save_args[0] - assert not bool(save_args[0]["model"]) - assert "optimizer" in save_args[0] - assert not bool(save_args[0]["optimizer"]) - assert save_args[1] == "abc" - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_exclude_optimizer_states(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc", include_optimizer_states=False) - - save_args, _ = save_mock.call_args - assert "model" in save_args[0] - assert not bool(save_args[0]["model"]) - assert "optimizer" not in save_args[0] - assert save_args[1] == "abc" - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_user_dict(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc", user_dict={"abc": np.arange(4)}) - - save_args, _ = save_mock.call_args - assert "user_dict" in save_args[0] - assert save_args[0]["user_dict"] == _checkpoint_storage.to_serialized_hex({"abc": np.arange(4)}) - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -def test_load_checkpoint(aggregate_checkpoints_mock, load_mock): - trainer = _create_trainer() - trainer_options = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - } - state_dict = { - "model": {}, - "optimizer": {}, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - }, - } - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options, state_dict] - trainer.load_checkpoint("abc") - - args_list = load_mock.call_args_list - load_args, load_kwargs = args_list[0] - assert load_args[0] == "abc" - assert load_kwargs["key"] == "trainer_options" - load_args, load_kwargs = args_list[1] - assert load_args[0] == "abc" - assert "key" not in load_kwargs - assert not aggregate_checkpoints_mock.called - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -@pytest.mark.parametrize( - "trainer_options", - [ - { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(4), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(4), - "zero_stage": np.int64(1), - }, - { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(1), - }, - { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(1), - }, - ], -) -def test_load_checkpoint_aggregation_required_zero_enabled(aggregate_checkpoints_mock, load_mock, trainer_options): - trainer = _create_trainer() - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options] - trainer.load_checkpoint("abc") - - args_list = load_mock.call_args_list - load_args, load_kwargs = args_list[0] - assert load_args[0] == "abc" - assert load_kwargs["key"] == "trainer_options" - assert aggregate_checkpoints_mock.called - call_args, _ = aggregate_checkpoints_mock.call_args - assert call_args[0] == tuple(["abc"]) - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock): - trainer = _create_trainer() - trainer_options = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - } - state_dict = { - "model": {}, - "optimizer": {}, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - }, - "user_dict": _checkpoint_storage.to_serialized_hex({"array": torch.tensor(np.arange(5))}), - } - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options, state_dict] - user_dict = trainer.load_checkpoint("abc") - - assert torch.all(torch.eq(user_dict["array"], torch.tensor(np.arange(5)))) - - -@patch("onnxruntime.training._checkpoint_storage.load") -def test_checkpoint_aggregation(load_mock): - trainer_options1 = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - trainer_options2 = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(1), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - - state_dict1 = { - "model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "optimizer_sharded": { - "Moment_1": np.array([9, 8, 7]), - "Moment_2": np.array([99, 88, 77]), - "Step": np.array([5]), - }, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}}, - } - - state_dict2 = { - "model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "optimizer_sharded": { - "Moment_1": np.array([6, 5, 4]), - "Moment_2": np.array([66, 55, 44]), - "Step": np.array([5]), - }, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(1), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}}, - } - - load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] - state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False) - - assert (state_dict["model"]["full_precision"]["optimizer_sharded"] == np.array([1, 2, 3])).all() - assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Step"] == np.array([5])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all() - - assert state_dict["trainer_options"]["mixed_precision"] is False - assert state_dict["trainer_options"]["world_rank"] == 0 - assert state_dict["trainer_options"]["world_size"] == 1 - assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1 - assert state_dict["trainer_options"]["data_parallel_size"] == 1 - assert state_dict["trainer_options"]["zero_stage"] == 0 - assert state_dict["trainer_options"]["optimizer_name"] == b"Adam" - - -@patch("onnxruntime.training._checkpoint_storage.load") -def test_checkpoint_aggregation_mixed_precision(load_mock): - trainer_options1 = { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - trainer_options2 = { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(1), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - - state_dict1 = { - "model": {"full_precision": {"sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "sharded": {"Moment_1": np.array([9, 8, 7]), "Moment_2": np.array([99, 88, 77]), "Step": np.array([5])}, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"sharded": {"original_dim": np.array([2, 3])}}, - } - - state_dict2 = { - "model": {"full_precision": {"sharded": np.array([4, 5, 6]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "sharded": {"Moment_1": np.array([6, 5, 4]), "Moment_2": np.array([66, 55, 44]), "Step": np.array([5])}, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(1), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"sharded": {"original_dim": np.array([2, 3])}}, - } - - load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] - state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False) - - assert (state_dict["model"]["full_precision"]["sharded"] == np.array([[1, 2, 3], [4, 5, 6]])).all() - assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all() - assert (state_dict["optimizer"]["sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all() - assert (state_dict["optimizer"]["sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all() - assert (state_dict["optimizer"]["sharded"]["Step"] == np.array([5])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all() - - assert state_dict["trainer_options"]["mixed_precision"] is True - assert state_dict["trainer_options"]["world_rank"] == 0 - assert state_dict["trainer_options"]["world_size"] == 1 - assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1 - assert state_dict["trainer_options"]["data_parallel_size"] == 1 - assert state_dict["trainer_options"]["zero_stage"] == 0 - assert state_dict["trainer_options"]["optimizer_name"] == b"Adam" diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py deleted file mode 100644 index fa13625f0ddac..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ /dev/null @@ -1,2460 +0,0 @@ -import inspect -import os -import tempfile -from functools import partial - -import _test_commons -import _test_helpers -import onnx -import pytest -import torch -import torch.nn.functional as F -from numpy.testing import assert_allclose -from packaging.version import Version as StrictVersion - -from onnxruntime import SessionOptions, set_seed -from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import PropagateCastOpsStrategy, TrainStepInfo, _utils, amp -from onnxruntime.training import model_desc_validation as md_val -from onnxruntime.training import optim, orttrainer, orttrainer_options - -############################################################################### -# Testing starts here ######################################################### -############################################################################### - -pytorch_110 = StrictVersion(".".join(torch.__version__.split(".")[:2])) >= StrictVersion("1.10.0") - - -def get_model_opset(model_onnx): - for op in model_onnx.opset_import: - if op.domain == "": - return op.version - return None - - -@pytest.mark.parametrize( - "test_input", - [({}), ({"batch": {}, "device": {}, "distributed": {}, "mixed_precision": {}, "utils": {}, "_internal_use": {}})], -) -def testORTTrainerOptionsDefaultValues(test_input): - """Test different ways of using default values for incomplete input""" - - expected_values = { - "batch": {"gradient_accumulation_steps": 1}, - "device": {"id": "cuda", "mem_limit": 0}, - "distributed": { - "world_rank": 0, - "world_size": 1, - "local_rank": 0, - "data_parallel_size": 1, - "horizontal_parallel_size": 1, - "pipeline_parallel": { - "pipeline_parallel_size": 1, - "num_pipeline_micro_batches": 1, - "pipeline_cut_info_string": "", - "sliced_schema": {}, - "sliced_axes": {}, - "sliced_tensor_names": [], - }, - "allreduce_post_accumulation": False, - "deepspeed_zero_optimization": { - "stage": 0, - }, - "enable_adasum": False, - }, - "lr_scheduler": None, - "mixed_precision": {"enabled": False, "loss_scaler": None}, - "graph_transformer": { - "attn_dropout_recompute": False, - "gelu_recompute": False, - "transformer_layer_recompute": False, - "number_recompute_layers": 0, - "propagate_cast_ops_config": {"strategy": PropagateCastOpsStrategy.FLOOD_FILL, "level": 1, "allow": []}, - }, - "utils": { - "frozen_weights": [], - "grad_norm_clip": True, - "memory_efficient_gradient": False, - "run_symbolic_shape_infer": False, - }, - "debug": { - "deterministic_compute": False, - "check_model_export": False, - "graph_save_paths": { - "model_after_graph_transforms_path": "", - "model_with_gradient_graph_path": "", - "model_with_training_graph_path": "", - "model_with_training_graph_after_optimization_path": "", - }, - }, - "_internal_use": { - "enable_internal_postprocess": True, - "extra_postprocess": None, - "onnx_opset_version": 14, - "enable_onnx_contrib_ops": True, - }, - "provider_options": {}, - "session_options": None, - } - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values._validated_opts == expected_values - - -@pytest.mark.parametrize( - "input,error_msg", - [ - ( - {"mixed_precision": {"enabled": 1}}, - "Invalid options: {'mixed_precision': [{'enabled': ['must be of boolean type']}]}", - ) - ], -) -def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema(input, error_msg): - """Test an invalid input based on schema validation error message""" - - with pytest.raises(ValueError) as e: - orttrainer_options.ORTTrainerOptions(input) - assert str(e.value) == error_msg - - -@pytest.mark.parametrize( - "input_dict,input_dtype,output_dtype", - [ - ( - {"inputs": [("in0", [])], "outputs": [("out0", []), ("out1", [])]}, - (torch.int,), - ( - torch.float, - torch.int32, - ), - ), - ({"inputs": [("in0", ["batch", 2, 3])], "outputs": [("out0", [], True)]}, (torch.int8,), (torch.int16,)), - ( - { - "inputs": [ - ("in0", []), - ("in1", [1]), - ("in2", [1, 2]), - ("in3", [1000, "dyn_ax1"]), - ("in4", ["dyn_ax1", "dyn_ax2", "dyn_ax3"]), - ], - "outputs": [("out0", [], True), ("out1", [1], False), ("out2", [1, "dyn_ax1", 3])], - }, - ( - torch.float, - torch.uint8, - torch.bool, - torch.double, - torch.half, - ), - (torch.float, torch.float, torch.int64), - ), - ], -) -def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype): - r"""Test different ways of using default values for incomplete input""" - - model_description = md_val._ORTTrainerModelDesc(input_dict) - - # Validating hard-coded learning rate description - assert model_description.learning_rate.name == md_val.LEARNING_RATE_IO_DESCRIPTION_NAME - assert model_description.learning_rate.shape == [1] - assert model_description.learning_rate.dtype == torch.float32 - - # Validating model description from user - for idx, i_desc in enumerate(model_description.inputs): - assert isinstance(i_desc, model_description._InputDescription) - assert len(i_desc) == 2 - assert input_dict["inputs"][idx][0] == i_desc.name - assert input_dict["inputs"][idx][1] == i_desc.shape - for idx, o_desc in enumerate(model_description.outputs): - assert isinstance(o_desc, model_description._OutputDescription) - assert len(o_desc) == 3 - assert input_dict["outputs"][idx][0] == o_desc.name - assert input_dict["outputs"][idx][1] == o_desc.shape - is_loss = input_dict["outputs"][idx][2] if len(input_dict["outputs"][idx]) == 3 else False - assert is_loss == o_desc.is_loss - - # Set all_finite name and check its description - model_description.all_finite = md_val.ALL_FINITE_IO_DESCRIPTION_NAME - assert model_description.all_finite.name == md_val.ALL_FINITE_IO_DESCRIPTION_NAME - assert model_description.all_finite.shape == [1] - assert model_description.all_finite.dtype == torch.bool - - # Set loss_scale_input and check its description - model_description.loss_scale_input = md_val.LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME - assert model_description.loss_scale_input.name == md_val.LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME - assert model_description.loss_scale_input.shape == [] - assert model_description.loss_scale_input.dtype == torch.float32 - - # Append type to inputs/outputs tuples - for idx, i_desc in enumerate(model_description.inputs): # noqa: B007 - model_description.add_type_to_input_description(idx, input_dtype[idx]) - for idx, o_desc in enumerate(model_description.outputs): # noqa: B007 - model_description.add_type_to_output_description(idx, output_dtype[idx]) - - # Verify inputs/outputs tuples are replaced by the typed counterparts - for idx, i_desc in enumerate(model_description.inputs): - assert isinstance(i_desc, model_description._InputDescriptionTyped) - assert input_dtype[idx] == i_desc.dtype - for idx, o_desc in enumerate(model_description.outputs): - assert isinstance(o_desc, model_description._OutputDescriptionTyped) - assert output_dtype[idx] == o_desc.dtype - - -@pytest.mark.parametrize( - "input_dict,error_msg", - [ - ( - {"inputs": [(True, [])], "outputs": [(True, [])]}, - "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], " - "'outputs': [{0: ['the first element of the tuple (aka name) must be a string']}]}", - ), - ( - {"inputs": [("in1", None)], "outputs": [("out1", None)]}, - "Invalid model_desc: {'inputs': [{0: ['the second element of the tuple (aka shape) must be a list']}], " - "'outputs': [{0: ['the second element of the tuple (aka shape) must be a list']}]}", - ), - ( - {"inputs": [("in1", [])], "outputs": [("out1", [], None)]}, - "Invalid model_desc: {'outputs': [{0: ['the third element of the tuple (aka is_loss) must be a boolean']}]}", - ), - ( - {"inputs": [("in1", [True])], "outputs": [("out1", [True])]}, - "Invalid model_desc: {'inputs': [{0: ['each shape must be either a string or integer']}], " - "'outputs': [{0: ['each shape must be either a string or integer']}]}", - ), - ( - {"inputs": [("in1", [])], "outputs": [("out1", [], True), ("out2", [], True)]}, - "Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}", - ), - ( - {"inputz": [("in1", [])], "outputs": [("out1", [], True)]}, - "Invalid model_desc: {'inputs': ['required field'], 'inputz': ['unknown field']}", - ), - ( - {"inputs": [("in1", [])], "outputz": [("out1", [], True)]}, - "Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}", - ), - ], -) -def testORTTrainerModelDescInvalidSchemas(input_dict, error_msg): - r"""Test different ways of using default values for incomplete input""" - with pytest.raises(ValueError) as e: - md_val._ORTTrainerModelDesc(input_dict) - assert str(e.value) == error_msg - - -def testDynamicLossScaler(): - rtol = 1e-7 - default_scaler = amp.loss_scaler.DynamicLossScaler() - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(default_scaler.loss_scale, float(1 << 16), rtol=rtol, err_msg="loss scale mismatch") - assert default_scaler.up_scale_window == 2000 - assert_allclose(default_scaler.min_loss_scale, 1.0, rtol=rtol, err_msg="min loss scale mismatch") - assert_allclose(default_scaler.max_loss_scale, float(1 << 24), rtol=rtol, err_msg="max loss scale mismatch") - - # Performing 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True) - loss_scale = float(1 << 16) - for cycles in range(1, 10): - # 1999 updates without overflow produces 1999 stable steps - for i in range(1, 2000): - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == i - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg=f"loss scale mismatch at update {i}") - - # 2000th update without overflow doubles the loss and zero stable steps until max_loss_scale is reached - new_loss_scale = default_scaler.update(train_step_info) - if cycles <= 8: - loss_scale *= 2 - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # After 8 cycles, loss scale should be float(1 << 16)*(2**8) - assert_allclose(new_loss_scale, float(1 << 16) * (2**8), rtol=rtol, err_msg="loss scale mismatch") - - # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on - loss_scale = float(1 << 16) * (2**8) - for count in range(1, 2050): - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == (count % 2000) - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # Setting train_step_info.all_finite = False to test down scaling - train_step_info.all_finite = False - - # Performing 24 updates to half the loss scale each time - loss_scale = float(1 << 16) * (2**8) - for count in range(1, 25): # noqa: B007 - new_loss_scale = default_scaler.update(train_step_info) - loss_scale /= 2 - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # After 24 updates with gradient overflow, loss scale is 1.0 - assert_allclose(new_loss_scale, 1.0, rtol=rtol, err_msg="loss scale mismatch") - - # After 25 updates, min_loss_scale is reached and loss scale is not halfed from that point on - for count in range(1, 5): # noqa: B007 - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - -def testDynamicLossScalerCustomValues(): - rtol = 1e-7 - scaler = amp.loss_scaler.DynamicLossScaler( - automatic_update=False, loss_scale=3, up_scale_window=7, min_loss_scale=5, max_loss_scale=10 - ) - assert scaler.automatic_update is False - assert_allclose(scaler.loss_scale, 3, rtol=rtol, err_msg="loss scale mismatch") - assert_allclose(scaler.min_loss_scale, 5, rtol=rtol, err_msg="min loss scale mismatch") - assert_allclose(scaler.max_loss_scale, 10, rtol=rtol, err_msg="max loss scale mismatch") - assert scaler.up_scale_window == 7 - - -def testTrainStepInfo(): - """Test valid initializations of TrainStepInfo""" - - optimizer_config = optim.LambConfig() - fetches = ["out1", "out2"] - step_info = orttrainer.TrainStepInfo( - optimizer_config=optimizer_config, all_finite=False, fetches=fetches, optimization_step=123, step=456 - ) - assert step_info.optimizer_config == optimizer_config - assert step_info.all_finite is False - assert step_info.fetches == fetches - assert step_info.optimization_step == 123 - assert step_info.step == 456 - - step_info = orttrainer.TrainStepInfo(optimizer_config) - assert step_info.optimizer_config == optimizer_config - assert step_info.all_finite is True - assert step_info.fetches == [] - assert step_info.optimization_step == 0 - assert step_info.step == 0 - - -@pytest.mark.parametrize( - "invalid_input", - [ - (-1), - ("Hello"), - ], -) -def testTrainStepInfoInvalidInput(invalid_input): - """Test invalid initialization of TrainStepInfo""" - optimizer_config = optim.LambConfig() - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, all_finite=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, fetches=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, optimization_step=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, step=invalid_input) - - -@pytest.mark.parametrize( - "optim_name,lr,alpha,default_alpha", - [ - ("AdamOptimizer", 0.1, 0.2, None), - ("LambOptimizer", 0.2, 0.3, None), - ("SGDOptimizer", 0.3, 0.4, None), - ("SGDOptimizer", 0.3, 0.4, 0.5), - ], -) -def testOptimizerConfig(optim_name, lr, alpha, default_alpha): - """Test initialization of _OptimizerConfig""" - defaults = {"lr": lr, "alpha": alpha} - params = [{"params": ["fc1.weight", "fc2.weight"]}] - if default_alpha is not None: - params[0].update({"alpha": default_alpha}) - else: - params[0].update({"alpha": alpha}) - cfg = optim.config._OptimizerConfig(name=optim_name, params=params, defaults=defaults) - - assert cfg.name == optim_name - rtol = 1e-07 - assert_allclose(defaults["lr"], cfg.lr, rtol=rtol, err_msg="lr mismatch") - - # 1:1 mapping between defaults and params's hyper parameters - for param in params: - for k in param: - if k != "params": - assert k in cfg.defaults, "hyper parameter {k} not present in one of the parameter params" - for k in cfg.defaults: - for param in cfg.params: - assert k in param, "hyper parameter {k} not present in one of the parameter params" - - -@pytest.mark.parametrize( - "optim_name,defaults,params", - [ - ("AdamOptimizer", {"lr": -1}, []), # invalid lr - ("FooOptimizer", {"lr": 0.001}, []), # invalid name - ("SGDOptimizer", [], []), # invalid type(defaults) - (optim.AdamConfig, {"lr": 0.003}, []), # invalid type(name) - ("AdamOptimizer", {"lr": None}, []), # missing 'lr' hyper parameter - ("SGDOptimizer", {"lr": 0.004}, {}), # invalid type(params) - # invalid type(params[i]) - ("AdamOptimizer", {"lr": 0.005, "alpha": 2}, [[]]), - # missing 'params' at 'params' - ("AdamOptimizer", {"lr": 0.005, "alpha": 2}, [{"alpha": 1}]), - # missing 'alpha' at 'defaults' - ("AdamOptimizer", {"lr": 0.005}, [{"params": "param1", "alpha": 1}]), - ], -) -def testOptimizerConfigInvalidInputs(optim_name, defaults, params): - """Test invalid initialization of _OptimizerConfig""" - - with pytest.raises(AssertionError): - optim.config._OptimizerConfig(name=optim_name, params=params, defaults=defaults) - - -def testOptimizerConfigSGD(): - """Test initialization of SGD""" - cfg = optim.SGDConfig() - assert cfg.name == "SGDOptimizer" - - rtol = 1e-07 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - - cfg = optim.SGDConfig(lr=0.002) - assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch") - - # SGD does not support params - with pytest.raises(AssertionError) as e: - params = [{"params": ["layer1.weight"], "lr": 0.1}] - optim.SGDConfig(params=params, lr=0.002) - assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert str(e.value) == "'params' must be an empty list for SGD optimizer" - - -def testOptimizerConfigAdam(): - """Test initialization of Adam""" - cfg = optim.AdamConfig() - assert cfg.name == "AdamOptimizer" - - rtol = 1e-7 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch") - assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch") - assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch") - assert_allclose(1e-8, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch") - assert_allclose(1.0, cfg.max_norm_clip, rtol=rtol, err_msg="max_norm_clip mismatch") - assert cfg.do_bias_correction is True, "lambda_coef mismatch" - assert cfg.weight_decay_mode == optim.AdamConfig.DecayMode.BEFORE_WEIGHT_UPDATE, "weight_decay_mode mismatch" - - -def testOptimizerConfigLamb(): - """Test initialization of Lamb""" - cfg = optim.LambConfig() - assert cfg.name == "LambOptimizer" - rtol = 1e-7 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch") - assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch") - assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch") - assert cfg.ratio_min == float("-inf"), "ratio_min mismatch" - assert cfg.ratio_max == float("inf"), "ratio_max mismatch" - assert_allclose(1e-6, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch") - assert_allclose(1.0, cfg.max_norm_clip, rtol=rtol, err_msg="max_norm_clip mismatch") - assert cfg.do_bias_correction is False, "do_bias_correction mismatch" - - -@pytest.mark.parametrize("optim_name", [("Adam"), ("Lamb")]) -def testOptimizerConfigParams(optim_name): - rtol = 1e-7 - params = [{"params": ["layer1.weight"], "alpha": 0.1}] - if optim_name == "Adam": - cfg = optim.AdamConfig(params=params, alpha=0.2) - elif optim_name == "Lamb": - cfg = optim.LambConfig(params=params, alpha=0.2) - else: - raise ValueError("invalid input") - assert len(cfg.params) == 1, "params should have length 1" - assert_allclose(cfg.params[0]["alpha"], 0.1, rtol=rtol, err_msg="invalid lr on params[0]") - - -@pytest.mark.parametrize("optim_name", [("Adam"), ("Lamb")]) -def testOptimizerConfigInvalidParams(optim_name): - # lr is not supported within params - with pytest.raises(AssertionError) as e: - params = [{"params": ["layer1.weight"], "lr": 0.1}] - if optim_name == "Adam": - optim.AdamConfig(params=params, lr=0.2) - elif optim_name == "Lamb": - optim.LambConfig(params=params, lr=0.2) - else: - raise ValueError("invalid input") - assert str(e.value) == "'lr' is not supported inside params" - - -def testLinearLRSchedulerCreation(): - total_steps = 10 - warmup = 0.05 - - lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(total_steps, warmup) - - # Initial state - assert lr_scheduler.total_steps == total_steps - assert lr_scheduler.warmup == warmup - - -@pytest.mark.parametrize( - "lr_scheduler,expected_values", - [ - (optim.lr_scheduler.ConstantWarmupLRScheduler, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0]), - ( - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.9763960957919413, - 0.9059835861602854, - 0.7956724530494887, - 0.6563036824392345, - 0.5015739416158049, - 0.34668951940611276, - 0.2068719061737831, - 0.09586187986225325, - 0.0245691111902418, - ], - ), - (optim.lr_scheduler.LinearWarmupLRScheduler, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2]), - ( - optim.lr_scheduler.PolyWarmupLRScheduler, - [ - 0.0, - 0.9509018036072144, - 0.9008016032064128, - 0.8507014028056112, - 0.8006012024048097, - 0.750501002004008, - 0.7004008016032064, - 0.6503006012024048, - 0.6002004008016032, - 0.5501002004008015, - ], - ), - ], -) -def testLRSchedulerUpdateImpl(lr_scheduler, expected_values): - # Test tolerance - rtol = 1e-03 - - # Initial state - initial_lr = 1 - total_steps = 10 - warmup = 0.5 - optimizer_config = optim.SGDConfig(lr=initial_lr) - lr_scheduler = lr_scheduler(total_steps, warmup) - - # First half is warmup - for optimization_step in range(total_steps): - # Emulate ORTTRainer.train_step() call that updates its train_step_info - train_step_info = TrainStepInfo(optimizer_config=optimizer_config, optimization_step=optimization_step) - - lr_scheduler._step(train_step_info) - lr_list = lr_scheduler.get_last_lr() - assert len(lr_list) == 1 - assert_allclose(lr_list[0], expected_values[optimization_step], rtol=rtol, err_msg="lr mismatch") - - -def testInstantiateORTTrainerOptions(): - session_options = SessionOptions() - session_options.enable_mem_pattern = False - provider_options = {"EP1": {"key": "val"}} - opts = {"session_options": session_options, "provider_options": provider_options} - opts = orttrainer.ORTTrainerOptions(opts) - assert opts.session_options.enable_mem_pattern is False - assert opts._validated_opts["provider_options"]["EP1"]["key"] == "val" - - -@pytest.mark.parametrize( - "step_fn, lr_scheduler, expected_lr_values, device", - [ - ("train_step", None, None, "cuda"), - ("eval_step", None, None, "cpu"), - ( - "train_step", - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0], - "cpu", - ), - ( - "train_step", - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.2, - 0.4, - 0.6, - 0.8, - 1.0, - 0.9045084971874737, - 0.6545084971874737, - 0.34549150281252633, - 0.09549150281252633, - ], - "cuda", - ), - ( - "train_step", - optim.lr_scheduler.LinearWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2], - "cpu", - ), - ( - "train_step", - optim.lr_scheduler.PolyWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.80000002, 0.60000004, 0.40000006000000005, 0.20000007999999997], - "cuda", - ), - ], -) -def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device): - total_steps = 1 - initial_lr = 1.0 - rtol = 1e-3 - - # PyTorch Transformer model as example - opts = {"device": {"id": device}} - if lr_scheduler: - total_steps = 10 - opts.update({"lr_scheduler": lr_scheduler(total_steps=total_steps, warmup=0.5)}) - opts = orttrainer.ORTTrainerOptions(opts) - optim_config = optim.LambConfig(lr=initial_lr) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - - # Run a train or evaluation step - if step_fn == "eval_step": - data, targets = batcher_fn(val_data, 0) - elif step_fn == "train_step": - data, targets = batcher_fn(train_data, 0) - else: - raise ValueError("Invalid step_fn") - - # Export model to ONNX - if step_fn == "eval_step": - step_fn = trainer.eval_step - output = trainer.eval_step(data, targets) - elif step_fn == "train_step": - step_fn = trainer.train_step - for i in range(total_steps): - output = trainer.train_step(data, targets) - if lr_scheduler: - lr_list = trainer.options.lr_scheduler.get_last_lr() - assert_allclose(lr_list[0], expected_lr_values[i], rtol=rtol, err_msg="lr mismatch") - else: - raise ValueError("Invalid step_fn") - assert trainer._onnx_model is not None - - # Check output shape after train/eval step - for out, desc in zip(output, trainer.model_desc.outputs): - if trainer.loss_fn and desc.is_loss: - continue - assert list(out.size()) == desc.shape - - # Check name, shape and dtype of the first len(forward.parameters) ORT graph inputs - sig = inspect.signature(model.forward) - for i in range(len(sig.parameters.keys())): - input_name = trainer.model_desc.inputs[i][0] - input_dim = trainer.model_desc.inputs[i][1] - input_type = trainer.model_desc.inputs[i][2] - - assert trainer._onnx_model.graph.input[i].name == input_name - for dim_idx, dim in enumerate(trainer._onnx_model.graph.input[i].type.tensor_type.shape.dim): - assert input_dim[dim_idx] == dim.dim_value - assert input_type == _utils.dtype_onnx_to_torch( - trainer._onnx_model.graph.input[i].type.tensor_type.elem_type - ) - - opset = get_model_opset(trainer._onnx_model) - - # Check name, shape and dtype of the ORT graph outputs - for i in range(len(trainer.model_desc.outputs)): - output_name = trainer.model_desc.outputs[i][0] - output_dim = trainer.model_desc.outputs[i][1] - output_type = trainer.model_desc.outputs[i][3] - - assert trainer._onnx_model.graph.output[i].name == output_name - for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim): - if opset is None or opset <= 12: - assert output_dim[dim_idx] == dim.dim_value - assert output_type == _utils.dtype_onnx_to_torch( - trainer._onnx_model.graph.output[i].type.tensor_type.elem_type - ) - - # Save current model as ONNX as a file - file_name = os.path.join("_____temp_onnx_model.onnx") - trainer.save_as_onnx(file_name) - assert os.path.exists(file_name) - with open(file_name, "rb") as f: - bin_str = f.read() - reload_onnx_model = onnx.load_model_from_string(bin_str) - os.remove(file_name) - - # Create a new trainer from persisted ONNX model and compare with original ONNX model - trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config) - step_fn(data, targets) - assert trainer_from_onnx._onnx_model is not None - assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model) - assert trainer_from_onnx._onnx_model == trainer._onnx_model - assert trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph - assert onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph( - trainer._onnx_model.graph - ) - - -@pytest.mark.parametrize("seed, device", [(0, "cpu"), (24, "cuda")]) -def testORTDeterministicCompute(seed, device): - # Common setup - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - {"debug": {"deterministic_compute": True}, "device": {"id": device, "mem_limit": 10 * 1024 * 1024}} - ) - - # Setup for the first ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - _ = first_trainer.train_step(data, targets) - assert first_trainer._onnx_model is not None - - # Setup for the second ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - model, _, _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device) - second_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - _ = second_trainer.train_step(data, targets) - assert second_trainer._onnx_model is not None - - # Compare two different instances with identical setup - assert id(first_trainer._onnx_model) != id(second_trainer._onnx_model) - _test_helpers.assert_onnx_weights(first_trainer, second_trainer) - - -@pytest.mark.parametrize( - "seed,device,expected_loss,fetches", - [ - (321, "cuda", [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], False), - (321, "cuda", [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], True), - ], -) -def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches): - return # TODO: re-enable after nondeterminism on backend is fixed. update numbers - - rtol = 1e-3 - total_steps = len(expected_loss) - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - loss_scaler = amp.DynamicLossScaler() - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - if fetches: - trainer._train_step_info.fetches = ["loss"] - loss = trainer.train_step(data, targets) - else: - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Eval once just to test fetches in action - val_data, val_targets = batcher_fn(val_data, 0) - if fetches: - trainer._train_step_info.fetches = ["loss"] - loss = trainer.eval_step(val_data, val_targets) - trainer._train_step_info.fetches = [] - loss, _ = trainer.eval_step(val_data, val_targets) - - # Compare loss to ground truth computed from current ORTTrainer API - _test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=rtol) - assert trainer._onnx_model is not None - - -def _recompute_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - expected_loss = { - 12: [10.5598, 10.4591, 10.3477, 10.2726, 10.1945], - 14: [10.54088, 10.498755, 10.386827, 10.338747, 10.262459], - } - return [ - (False, False, False, 0, expected_loss), # no recompute - (True, False, False, 0, expected_loss), # attn_dropout recompute - (False, True, False, 0, expected_loss), # gelu recompute - (False, False, True, 0, expected_loss), # transformer_layer recompute - (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer - ] - elif device_capability_major == 5: # M60 for CI machines - expected_loss = { - 12: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], - 14: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], - } - return [ - (False, False, False, 0, expected_loss), # no recompute - (True, False, False, 0, expected_loss), # attn_dropout recompute - (False, True, False, 0, expected_loss), # gelu recompute - (False, False, True, 0, expected_loss), # transformer_layer recompute - (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer - ] - - -@pytest.mark.parametrize("attn_dropout, gelu, transformer_layer, number_layers, expected_loss", _recompute_data()) -def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers, expected_loss): - seed = 321 - device = "cuda" - rtol = 1e-3 - total_steps = len(expected_loss[12]) - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "graph_transformer": { - "attn_dropout_recompute": attn_dropout, - "gelu_recompute": gelu, - "transformer_layer_recompute": transformer_layer, - "number_recompute_layers": number_layers, - }, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Compare loss to ground truth computed from current ORTTrainer API - assert trainer._onnx_model is not None - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, True, rtol=rtol) - - -@pytest.mark.parametrize( - "seed,device,gradient_accumulation_steps,total_steps,expected_loss", - [ - ( - 0, - "cuda", - 1, - 12, - [ - 10.5368022919, - 10.4146203995, - 10.3635568619, - 10.2650547028, - 10.2284049988, - 10.1304626465, - 10.0853414536, - 9.9987659454, - 9.9472427368, - 9.8832416534, - 9.8223171234, - 9.8222122192, - ], - ), - ( - 42, - "cuda", - 3, - 12, - [ - 10.6455879211, - 10.6247081757, - 10.6361322403, - 10.5187482834, - 10.5345087051, - 10.5487670898, - 10.4833698273, - 10.4600019455, - 10.4535751343, - 10.3774127960, - 10.4144191742, - 10.3757553101, - ], - ), - ( - 123, - "cuda", - 7, - 12, - [ - 10.5353469849, - 10.5261383057, - 10.5240392685, - 10.5013713837, - 10.5678377151, - 10.5452117920, - 10.5184345245, - 10.4271221161, - 10.4458627701, - 10.4864749908, - 10.4416503906, - 10.4467563629, - ], - ), - ( - 321, - "cuda", - 12, - 12, - [ - 10.5773944855, - 10.5428829193, - 10.5974750519, - 10.5416746140, - 10.6009902954, - 10.5684127808, - 10.5759754181, - 10.5636739731, - 10.5613927841, - 10.5825119019, - 10.6031589508, - 10.6199369431, - ], - ), - ], -) -def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps, expected_loss): - return # TODO: re-enable after nondeterminism on backend is fixed. update numbers - rtol = 1e-3 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(expected_loss, actual_loss, rtol=rtol) - - -@pytest.mark.parametrize( - "dynamic_axes", - [ - (True), - (False), - ], -) -def testORTTrainerDynamicShape(dynamic_axes): - # Common setup - device = "cuda" - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model( - device, dynamic_axes=dynamic_axes - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - total_steps = 10 - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - if dynamic_axes: - # Forcing batches with different sizes to exercise dynamic shapes - data = data[: -(i + 1)] - targets = targets[: -(i + 1) * data.size(1)] - _, _ = trainer.train_step(data, targets) - - assert trainer._onnx_model is not None - - -@pytest.mark.parametrize( - "enable_onnx_contrib_ops", - [ - (True), - (False), - ], -) -def testORTTrainerInternalUseContribOps(enable_onnx_contrib_ops): - # Common setup - device = "cuda" - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({"_internal_use": {"enable_onnx_contrib_ops": enable_onnx_contrib_ops}}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - data, targets = batcher_fn(train_data, 0) - if not enable_onnx_contrib_ops and not pytorch_110: - with pytest.raises(Exception): # noqa: B017 - _, _ = trainer.train_step(data, targets) - else: - _, _ = trainer.train_step(data, targets) - - -@pytest.mark.parametrize( - "model_params", - [ - ( - [ - "decoder.weight", - "transformer_encoder.layers.0.linear1.bias", - "transformer_encoder.layers.0.linear2.weight", - "transformer_encoder.layers.1.self_attn.out_proj.weight", - "transformer_encoder.layers.1.self_attn.out_proj.bias", - ] - ), - ], -) -def testORTTrainerFrozenWeights(model_params): - # Common setup - device = "cuda" - total_steps = 10 - - # Setup ORTTrainer WITHOUT frozen weights - options = orttrainer.ORTTrainerOptions({}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = trainer.train_step(data, targets) - - # All model_params must be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert all([param in session_state for param in model_params]) - - # Setup ORTTrainer WITH frozen weights - options = orttrainer.ORTTrainerOptions({"utils": {"frozen_weights": model_params}}) - model, _, _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = trainer.train_step(data, targets) - - # All model_params CANNOT be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert not all([param in session_state for param in model_params]) - - -@pytest.mark.parametrize( - "loss_scaler, optimizer_config, gradient_accumulation_steps", - [ - (None, optim.AdamConfig(), 1), - (None, optim.LambConfig(), 1), - (None, optim.SGDConfig(), 1), - (amp.DynamicLossScaler(), optim.AdamConfig(), 1), - (amp.DynamicLossScaler(), optim.LambConfig(), 5), - # (amp.DynamicLossScaler(), optim.SGDConfig(), 1), # SGD doesnt support fp16 - ], -) -def testORTTrainerStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps): - # Common setup - seed = 1 - - class LinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - - def forward(self, y=None, x=None): - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - model_desc = { - "inputs": [ - ("x", [2, 2]), - ( - "label", - [ - 2, - ], - ), - ], - "outputs": [("loss", [], True), ("output", [2, 4])], - } - - # Dummy data - data1 = torch.randn(2, 2) - label1 = torch.tensor([0, 1], dtype=torch.int64) - data2 = torch.randn(2, 2) - label2 = torch.tensor([0, 1], dtype=torch.int64) - - # Setup training based on test parameters - opts = { - "debug": {"deterministic_compute": True}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - if loss_scaler: - opts["mixed_precision"] = {"enabled": True, "loss_scaler": loss_scaler} - opts = orttrainer.ORTTrainerOptions(opts) - - # Training session 1 - torch.manual_seed(seed) - set_seed(seed) - pt_model = LinearModel() - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - trainer = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) - - # Check state_dict keys before train. Must be empty - state_dict = trainer.state_dict() - assert state_dict == {} - - # Train once and check initial state - trainer.train_step(x=data1, label=label1) - state_dict = trainer.state_dict() - assert all([weight in state_dict["model"]["full_precision"] for weight in ["linear.bias", "linear.weight"]]) - - # Initialize training session 2 from state of Training 1 - torch.manual_seed(seed) - set_seed(seed) - trainer2 = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) - trainer2.load_state_dict(state_dict) - - # Verify state was loaded properly - _test_commons.assert_all_states_close_ort(state_dict, trainer2._load_state_dict.args[0]) - - # Perform a second step in both training session 1 and 2 and verify they match - trainer.train_step(x=data2, label=label2) - state_dict = trainer.state_dict() - trainer2.train_step(x=data2, label=label2) - state_dict2 = trainer2.state_dict() - _test_commons.assert_all_states_close_ort(state_dict, state_dict2) - - -def testORTTrainerNonPickableModel(): - # Common setup - import threading - - seed = 1 - - class UnpickableModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - self._lock = threading.Lock() - - def forward(self, y=None, x=None): - with self._lock: - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - model_desc = { - "inputs": [ - ("x", [2, 2]), - ( - "label", - [ - 2, - ], - ), - ], - "outputs": [("loss", [], True), ("output", [2, 4])], - } - - # Dummy data - data = torch.randn(2, 2) - label = torch.tensor([0, 1], dtype=torch.int64) - - # Setup training based on test parameters - opts = orttrainer.ORTTrainerOptions({"debug": {"deterministic_compute": True}}) - - # Training session - torch.manual_seed(seed) - set_seed(seed) - pt_model = UnpickableModel() - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - optim_config = optim.AdamConfig() - trainer = orttrainer.ORTTrainer(pt_model, model_desc, optim_config, loss_fn=loss_fn, options=opts) - - # Train must succeed despite warning - _, _ = trainer.train_step(data, label) - - -############################################################################### -# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ -############################################################################### - - -@pytest.mark.parametrize("seed,device", [(1234, "cuda")]) -def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device): - # Common data - rtol = 1e-7 - total_steps = 5 - - # Setup for the experimental ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - # Training loop - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _ = trainer.train_step(data, targets) - - # Setup for the legacy ORTTrainer run - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, my_loss, model_desc, "LambOptimizer", None, lr_desc, device, _use_deterministic_compute=True - ) - # Training loop - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - - # Compare legacy vs experimental APIs - _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=rtol) - - -@pytest.mark.parametrize( - "seed,device", - [ - (321, "cuda"), - ], -) -def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device): - # Common data - total_steps = 128 - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - loss_scaler = amp.DynamicLossScaler() - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - experimental_preds_dtype = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, exp_preds = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - experimental_preds_dtype.append(exp_preds.dtype) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - loss_scaler = Legacy_LossScaler("ort_test_input_loss_scalar", True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - use_mixed_precision=True, - loss_scaler=loss_scaler, - ) - # Training loop - legacy_loss = [] - legacy_preds_dtype = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(leg_loss.cpu()) - legacy_preds_dtype.append(leg_preds.dtype) - - # Compare legacy vs experimental APIs - assert experimental_preds_dtype == legacy_preds_dtype - _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer) - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -@pytest.mark.parametrize( - "seed,device,gradient_accumulation_steps,total_steps", - [ - (0, "cuda", 1, 12), - (42, "cuda", 3, 12), - (123, "cuda", 7, 12), - (321, "cuda", 12, 12), - ], -) -def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps): - # Common data - torch.set_printoptions(precision=10) - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, _ = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - gradient_accumulation_steps=gradient_accumulation_steps, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(leg_loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -@pytest.mark.parametrize( - "seed,device,optimizer_config,lr_scheduler, get_lr_this_step", - [ - ( - 0, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 0, - "cuda", - optim.LambConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 0, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 42, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 42, - "cuda", - optim.LambConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 42, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 123, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 123, - "cuda", - optim.LambConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 123, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 321, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ( - 321, - "cuda", - optim.LambConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ( - 321, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ], -) -def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_config, lr_scheduler, get_lr_this_step): - # Common data - total_steps = 10 - lr = 0.001 - warmup = 0.5 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - torch.set_printoptions(precision=10) - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - options = orttrainer.ORTTrainerOptions( - {"device": {"id": device}, "debug": {"deterministic_compute": True}, "lr_scheduler": lr_scheduler} - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optimizer_config(lr=lr) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, exp_preds = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - - if optimizer_config == optim.AdamConfig: - legacy_optimizer_config = "AdamOptimizer" - elif optimizer_config == optim.LambConfig: - legacy_optimizer_config = "LambOptimizer" - elif optimizer_config == optim.SGDConfig: - legacy_optimizer_config = "SGDOptimizer" - else: - raise RuntimeError("Invalid optimizer_config") - - if ( - get_lr_this_step == _test_commons.legacy_constant_lr_scheduler - or get_lr_this_step == _test_commons.legacy_linear_lr_scheduler - ): - get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup) - elif get_lr_this_step == _test_commons.legacy_cosine_lr_scheduler: - get_lr_this_step = partial( - get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, cycles=cycles - ) - elif get_lr_this_step == _test_commons.legacy_poly_lr_scheduler: - get_lr_this_step = partial( - get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end - ) - else: - raise RuntimeError("Invalid get_lr_this_step") - - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - legacy_optimizer_config, - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - get_lr_this_step=get_lr_this_step, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, leg_preds = legacy_trainer.train_step(data, targets) - legacy_loss.append(leg_loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -def testLossScalerLegacyAndExperimentalFullCycle(): - orttrainer.TrainStepInfo( - optimizer_config=optim.LambConfig(lr=0.001), all_finite=True, fetches=[], optimization_step=0, step=0 - ) - new_ls = amp.DynamicLossScaler() - old_ls = Legacy_LossScaler("ort_test_input_loss_scaler", True) - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(new_ls.loss_scale, old_ls.loss_scale_) - assert new_ls.up_scale_window == old_ls.up_scale_window_ - assert_allclose(new_ls.min_loss_scale, old_ls.min_loss_scale_) - assert_allclose(new_ls.max_loss_scale, old_ls.max_loss_scale_) - - # Performing 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True) - for _cycles in range(1, 10): - # 1999 updates without overflow produces 1999 stable steps - for _i in range(1, 2000): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # 2000th update without overflow doubles the loss and zero stable steps until max_loss_scale is reached - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # After 8 cycles, loss scale should be float(1 << 16)*(2**8) - assert_allclose(new_loss_scale, old_loss_scale) - - # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on - for _count in range(1, 2050): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # Setting train_step_info.all_finite = False to test down scaling - train_step_info.all_finite = False - - # Performing 24 updates to half the loss scale each time - for _count in range(1, 25): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # After 24 updates with gradient overflow, loss scale is 1.0 - assert_allclose(new_loss_scale, old_loss_scale) - - # After 25 updates, min_loss_scale is reached and loss scale is not halfed from that point on - for _count in range(1, 5): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - -def testLossScalerLegacyAndExperimentalRandomAllFinite(): - new_ls = amp.DynamicLossScaler() - old_ls = Legacy_LossScaler("ort_test_input_loss_scaler", True) - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(new_ls.loss_scale, old_ls.loss_scale_) - assert new_ls.up_scale_window == old_ls.up_scale_window_ - assert_allclose(new_ls.min_loss_scale, old_ls.min_loss_scale_) - assert_allclose(new_ls.max_loss_scale, old_ls.max_loss_scale_) - - import random - - out = [] - for _ in range(1, 64): - train_step_info.all_finite = bool(random.getrandbits(1)) - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - out.append(new_loss_scale) - assert new_loss_scale > 1e-7 - - -def testORTTrainerRunSymbolicShapeInfer(): - # Common data - seed = 0 - total_steps = 12 - device = "cuda" - torch.set_printoptions(precision=10) - - # Setup without symbolic shape inference - torch.manual_seed(seed) - set_seed(seed) - options = orttrainer.ORTTrainerOptions({"device": {"id": device}, "debug": {"deterministic_compute": True}}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - expected_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - expected_loss.append(loss.cpu()) - - # Setup with symbolic shape inference - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - options.utils.run_symbolic_shape_infer = True - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - new_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - new_loss.append(loss.cpu()) - - # Setup with symbolic shape inference in legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - run_symbolic_shape_infer=True, - _use_deterministic_compute=True, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(loss.cpu()) - - # Compare losses - _test_helpers.assert_model_outputs(new_loss, expected_loss) - _test_helpers.assert_model_outputs(legacy_loss, expected_loss) - - -@pytest.mark.parametrize( - "test_input", - [ - ( - { - "distributed": {"enable_adasum": True}, - } - ) - ], -) -def testORTTrainerOptionsEnabledAdasumFlag(test_input): - """Test the enabled_adasum flag values when set enabled""" - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values.distributed.enable_adasum is True - - -@pytest.mark.parametrize( - "test_input", - [ - ( - { - "distributed": {"enable_adasum": False}, - } - ) - ], -) -def testORTTrainerOptionsDisabledAdasumFlag(test_input): - """Test the enabled_adasum flag values when set disabled""" - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values.distributed.enable_adasum is False - - -def testORTTrainerUnusedInput(): - class UnusedInputModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.mean(x) - - model = UnusedInputModel() - model_desc = {"inputs": [("x", [1]), ("y", [1])], "outputs": [("loss", [], True)]} - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config) - - # Run just one step to make sure there are no iobinding errors for the unused input. - try: - trainer.train_step(torch.FloatTensor([1.0]), torch.FloatTensor([1.0])) - except RuntimeError: - pytest.fail("RuntimeError doing train_step with an unused input.") - - -@pytest.mark.parametrize( - "debug_files", - [ - { - "model_after_graph_transforms_path": "transformed.onnx", - "model_with_gradient_graph_path": "transformed_grad.onnx", - "model_with_training_graph_path": "training.onnx", - "model_with_training_graph_after_optimization_path": "training_optimized.onnx", - }, - {"model_after_graph_transforms_path": "transformed.onnx", "model_with_training_graph_path": ""}, - ], -) -def testTrainingGraphExport(debug_files): - device = "cuda" - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - - with tempfile.TemporaryDirectory() as tempdir: - debug_paths = {} - for k, v in debug_files.items(): - debug_paths[k] = os.path.join(tempdir, v) - opts = orttrainer.ORTTrainerOptions({"device": {"id": device}, "debug": {"graph_save_paths": debug_paths}}) - optim_config = optim.AdamConfig() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - trainer.train_step(data, targets) - for k, v in debug_files.items(): - path = debug_paths[k] - if len(v) > 0: - assert os.path.isfile(path) - saved_graph = onnx.load(path).graph - if k == "model_with_training_graph_path": - assert any("AdamOptimizer" in n.op_type for n in saved_graph.node) - elif k == "model_with_gradient_graph_path": - assert any("Grad" in n.name for n in saved_graph.node) - elif k == "model_after_graph_transforms_path": - assert any("LayerNormalization" in n.op_type for n in saved_graph.node) - elif k == "model_with_training_graph_after_optimization_path": - assert any("FusedMatMul" in n.op_type for n in saved_graph.node) - # remove saved file - os.remove(path) - else: - assert not os.path.isfile(path) - - -def _adam_max_norm_clip_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.592951, - 10.067989, - 9.619152, - 9.245731, - 8.881137, - 8.578644, - 8.280573, - 8.063023, - 7.797933, - 7.486215, - 7.233806, - 7.011791, - ], - 14: [ - 10.584141, - 10.068119, - 9.581743, - 9.191472, - 8.880169, - 8.5352, - 8.311425, - 8.061202, - 7.773032, - 7.523009, - 7.258711, - 7.02805, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.592951, - 10.068722, - 9.620503, - 9.247791, - 8.883972, - 8.582286, - 8.285027, - 8.068308, - 7.803638, - 7.492318, - 7.240352, - 7.018665, - ], - 14: [ - 10.584141, - 10.068845, - 9.583107, - 9.193537, - 8.882966, - 8.538839, - 8.315872, - 8.066408, - 7.778978, - 7.529708, - 7.265849, - 7.035439, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.647908, - 10.144501, - 9.672352, - 9.306980, - 8.956026, - 8.602655, - 8.351079, - 8.088144, - 7.867220, - 7.564082, - 7.289846, - 7.073726, - ], - 14: [ - 10.697515, - 10.229034, - 9.765422, - 9.428294, - 9.080612, - 8.715208, - 8.459574, - 8.169073, - 7.940211, - 7.654147, - 7.390446, - 7.166227, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.647908, - 10.145191, - 9.673690, - 9.309031, - 8.959020, - 8.606632, - 8.355836, - 8.093478, - 7.873327, - 7.570731, - 7.296772, - 7.0809422, - ], - 14: [ - 10.697515, - 10.22967, - 9.766556, - 9.430037, - 9.083106, - 8.718601, - 8.463726, - 8.17396, - 7.945755, - 7.660188, - 7.396963, - 7.172944, - ], - }, - ), - ] - elif device_capability_major == 5: # M60 for CI machines (Python Packaging Pipeline) - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.618382, - 10.08292, - 9.603334, - 9.258133, - 8.917768, - 8.591574, - 8.318401, - 8.042292, - 7.783608, - 7.50226, - 7.236041, - 7.035602, - ], - 14: [ - 10.618382, - 10.08292, - 9.603334, - 9.258133, - 8.917768, - 8.591574, - 8.318401, - 8.042292, - 7.783608, - 7.50226, - 7.236041, - 7.035602, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.618382, - 10.083632, - 9.604639, - 9.260109, - 8.920504, - 8.595082, - 8.322799, - 8.047493, - 7.78929, - 7.508382, - 7.242587, - 7.042367, - ], - 14: [ - 10.618382, - 10.083632, - 9.604639, - 9.260109, - 8.920504, - 8.595082, - 8.322799, - 8.047493, - 7.78929, - 7.508382, - 7.242587, - 7.042367, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.68639, - 10.102986, - 9.647681, - 9.293091, - 8.958928, - 8.625297, - 8.351107, - 8.079577, - 7.840723, - 7.543044, - 7.284141, - 7.072688, - ], - 14: [ - 10.68639, - 10.102986, - 9.647681, - 9.293091, - 8.958928, - 8.625297, - 8.351107, - 8.079577, - 7.840723, - 7.543044, - 7.284141, - 7.072688, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.68639, - 10.103672, - 9.649025, - 9.295167, - 8.961777, - 8.629059, - 8.355571, - 8.084871, - 7.846589, - 7.549438, - 7.290722, - 7.079446, - ], - 14: [ - 10.697515, - 10.22967, - 9.766556, - 9.430037, - 9.083106, - 8.718601, - 8.463726, - 8.17396, - 7.945755, - 7.660188, - 7.396963, - 7.172944, - ], - }, - ), - ] - - -@pytest.mark.parametrize( - "seed,device,max_norm_clip,gradient_accumulation_steps,total_steps,expected_loss", _adam_max_norm_clip_data() -) -def testORTTrainerAdamMaxNormClip(seed, device, max_norm_clip, gradient_accumulation_steps, total_steps, expected_loss): - rtol = 1e-5 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.AdamConfig(lr=0.001, max_norm_clip=max_norm_clip) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu().item()) - - # Compare legacy vs experimental APIs - assert trainer._onnx_model is not None - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, rtol=rtol) - - -def _lamb_max_norm_clip_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.592951, - 10.487728, - 10.422251, - 10.350913, - 10.244248, - 10.213003, - 10.129222, - 10.095112, - 10.035983, - 9.974586, - 9.909771, - 9.874278, - ], - 14: [ - 10.584141, - 10.497192, - 10.389251, - 10.286045, - 10.231354, - 10.17018, - 10.066779, - 10.048138, - 9.958029, - 9.8908, - 9.82965, - 9.755484, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.592951, - 10.452503, - 10.349832, - 10.245314, - 10.106587, - 10.046009, - 9.934781, - 9.875164, - 9.792067, - 9.704592, - 9.617104, - 9.563070, - ], - 14: [ - 10.584141, - 10.461154, - 10.315399, - 10.178979, - 10.092329, - 9.999928, - 9.869949, - 9.824564, - 9.707565, - 9.61643, - 9.532847, - 9.439593, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.647908, - 10.566276, - 10.476154, - 10.406275, - 10.311079, - 10.240053, - 10.196469, - 10.113955, - 10.117376, - 10.013077, - 9.930301, - 9.893368, - ], - 14: [ - 10.697515, - 10.631279, - 10.528757, - 10.496689, - 10.411219, - 10.322109, - 10.297314, - 10.215549, - 10.149698, - 10.087336, - 10.010884, - 9.934544, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.647908, - 10.531957, - 10.405246, - 10.302971, - 10.176583, - 10.075583, - 10.005772, - 9.897825, - 9.875748, - 9.748932, - 9.642885, - 9.586762, - ], - 14: [ - 10.697515, - 10.596729, - 10.457815, - 10.393475, - 10.277581, - 10.158909, - 10.108126, - 10.000326, - 9.912526, - 9.826057, - 9.727899, - 9.633768, - ], - }, - ), - ] - elif device_capability_major == 5: # M60 for CI machines (Python Packaging Pipeline) - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.618382, - 10.50222, - 10.403347, - 10.35298, - 10.288447, - 10.237399, - 10.184225, - 10.089048, - 10.008952, - 9.972644, - 9.897674, - 9.84524, - ], - 14: [0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.618382, - 10.466732, - 10.330871, - 10.24715, - 10.150972, - 10.069127, - 9.98974, - 9.870169, - 9.763693, - 9.704323, - 9.605957, - 9.533117, - ], - 14: [1, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.68639, - 10.511692, - 10.447308, - 10.405255, - 10.334866, - 10.261473, - 10.169422, - 10.107138, - 10.069889, - 9.97798, - 9.928105, - 9.896435, - ], - 14: [2, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.68639, - 10.477489, - 10.376671, - 10.301725, - 10.200718, - 10.098477, - 9.97995, - 9.890104, - 9.828899, - 9.713555, - 9.639567, - 9.589856, - ], - 14: [3, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ] - - -@pytest.mark.parametrize( - "seed,device,max_norm_clip, gradient_accumulation_steps,total_steps,expected_loss", _lamb_max_norm_clip_data() -) -def testORTTrainerLambMaxNormClip(seed, device, max_norm_clip, gradient_accumulation_steps, total_steps, expected_loss): - rtol = 1e-3 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001, max_norm_clip=max_norm_clip) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu().item()) - - # Compare legacy vs experimental APIs - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, rtol=rtol) diff --git a/orttraining/orttraining/test/python/orttraining_test_transformers.py b/orttraining/orttraining/test/python/orttraining_test_transformers.py deleted file mode 100644 index dbaf4a293c466..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_transformers.py +++ /dev/null @@ -1,480 +0,0 @@ -import random -import unittest - -import numpy as np -import torch -from numpy.testing import assert_allclose -from orttraining_test_data_loader import BatchArgsOption, ids_tensor -from orttraining_test_utils import get_lr, run_test -from transformers import BertConfig, BertForPreTraining - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - - -class BertModelTest(unittest.TestCase): - class BertModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - scope=None, - device="cpu", - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.scope = scope - self.device = device - - # 1. superset of bert input/output descs - # see BertPreTrainedModel doc - self.input_ids_desc = IODescription( - "input_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size - ) - self.attention_mask_desc = IODescription( - "attention_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2 - ) - self.token_type_ids_desc = IODescription( - "token_type_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2 - ) - self.position_ids_desc = IODescription( - "position_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.max_position_embeddings - ) - self.head_mask_desc = IODescription( - "head_mask", [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2 - ) - self.inputs_embeds_desc = IODescription( - "inputs_embeds", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - - self.encoder_hidden_states_desc = IODescription( - "encoder_hidden_states", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - self.encoder_attention_mask_desc = IODescription( - "encoder_attention_mask", ["batch", "max_seq_len_in_batch"], torch.float32 - ) - - # see BertForPreTraining doc - self.masked_lm_labels_desc = IODescription( - "masked_lm_labels", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size - ) - self.next_sentence_label_desc = IODescription( - "next_sentence_label", - [ - "batch", - ], - torch.int64, - num_classes=2, - ) - - # outputs - self.loss_desc = IODescription( - "loss", - [ - 1, - ], - torch.float32, - ) - self.prediction_scores_desc = IODescription( - "prediction_scores", ["batch", "max_seq_len_in_batch", self.vocab_size], torch.float32 - ) - - self.seq_relationship_scores_desc = IODescription( - "seq_relationship_scores", ["batch", 2], torch.float32 - ) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32) - self.hidden_states_desc = IODescription( - "hidden_states", - [self.num_hidden_layers, "batch", "max_seq_len_in_batch", self.hidden_size], - torch.float32, - ) - self.attentions_desc = IODescription( - "attentions", - [ - self.num_hidden_layers, - "batch", - self.num_attention_heads, - "max_seq_len_in_batch", - "max_seq_len_in_batch", - ], - torch.float32, - ) - self.last_hidden_state_desc = IODescription( - "last_hidden_state", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - self.pooler_output_desc = IODescription("pooler_output", ["batch", self.hidden_size], torch.float32) - - def BertForPreTraining_descs(self): - return ModelDescription( - [ - self.input_ids_desc, - self.attention_mask_desc, - self.token_type_ids_desc, - self.masked_lm_labels_desc, - self.next_sentence_label_desc, - ], - # returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided - # hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states - [ - self.loss_desc, - self.prediction_scores_desc, - self.seq_relationship_scores_desc, - # hidden_states_desc, attentions_desc - ], - ) - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device) - - input_mask = None - if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device) - choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device) - - config = BertConfig( - vocab_size=self.vocab_size, - vocab_size_or_config_json_file=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - ) - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def create_and_check_bert_for_pretraining( - self, - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - option_use_internal_get_lr_this_step=[True], # noqa: B006 - option_use_internal_loss_scaler=[True], # noqa: B006 - ): - seed = 42 - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - model = BertForPreTraining(config=config) - model.eval() - loss, prediction_scores, seq_relationship_score = model( - input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - masked_lm_labels=token_labels, - next_sentence_label=sequence_labels, - ) - model_desc = ModelDescription( - [ - self.input_ids_desc, - self.attention_mask_desc, - self.token_type_ids_desc, - self.masked_lm_labels_desc, - self.next_sentence_label_desc, - ], - [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc], - ) - - from collections import namedtuple - - MyArgs = namedtuple( - "MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len" - ) - - dataset_len = 100 - epochs = 8 - max_steps = epochs * dataset_len - args = MyArgs( - local_rank=0, - world_size=1, - max_steps=max_steps, - learning_rate=0.00001, - warmup_proportion=0.01, - batch_size=13, - seq_len=7, - ) - - def get_lr_this_step(global_step): - return get_lr(args, global_step) - - loss_scaler = LossScaler("loss_scale_input_name", True, up_scale_window=2000) - - for fp16 in option_fp16: - for allreduce_post_accumulation in option_allreduce_post_accumulation: - for gradient_accumulation_steps in option_gradient_accumulation_steps: - for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step: - for use_internal_loss_scaler in option_use_internal_loss_scaler: - for split_batch in option_split_batch: - print("gradient_accumulation_steps:", gradient_accumulation_steps) - print("split_batch:", split_batch) - - seed = 42 - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - ( - old_api_loss_ort, - old_api_prediction_scores_ort, - old_api_seq_relationship_score_ort, - ) = run_test( - model, - model_desc, - self.device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - split_batch, - dataset_len, - epochs, - use_new_api=False, - ) - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - if use_internal_get_lr_this_step and use_internal_loss_scaler: - ( - new_api_loss_ort, - new_api_prediction_scores_ort, - new_api_seq_relationship_score_ort, - ) = run_test( - model, - model_desc, - self.device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - split_batch, - dataset_len, - epochs, - use_new_api=True, - ) - - assert_allclose(old_api_loss_ort, new_api_loss_ort) - assert_allclose(old_api_prediction_scores_ort, new_api_prediction_scores_ort) - assert_allclose( - old_api_seq_relationship_score_ort, new_api_seq_relationship_score_ort - ) - - def setUp(self): - self.model_tester = BertModelTest.BertModelTester(self) - - def test_for_pretraining_mixed_precision(self): - # It would be better to test both with/without mixed precision and allreduce_post_accumulation. - # However, stress test of all the 4 cases is not stable at least on the test machine. - # There we only test mixed precision and allreduce_post_accumulation because it is the most useful use cases. - option_fp16 = [True] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_mixed_precision_with_gradient_accumulation(self): - # It would be better to test both with/without mixed precision and allreduce_post_accumulation. - # However, stress test of all the 4 cases is not stable at least on the test machine. - # There we only test mixed precision and allreduce_post_accumulation because it is the most useful use cases. - option_fp16 = [True] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_all(self): - # This test is not stable because it create and run ORTSession multiple times. - # It occasionally gets seg fault at ~MemoryPattern() - # when releasing patterns_. In order not to block PR merging CI test, - # this test is broke into following individual tests. - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1, 8] - option_split_batch = [BatchArgsOption.List, BatchArgsOption.Dict, BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_list_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.List] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.Dict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_list_and_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_list_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.List] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.Dict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_list_and_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_utils.py b/orttraining/orttraining/test/python/orttraining_test_utils.py deleted file mode 100644 index 527cfb8a0ba7d..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_utils.py +++ /dev/null @@ -1,246 +0,0 @@ -import math - -import torch -from orttraining_test_data_loader import BatchArgsOption, create_ort_test_dataloader, split_batch - -from onnxruntime.capi.ort_trainer import IODescription, ORTTrainer -from onnxruntime.training import amp, optim, orttrainer -from onnxruntime.training.optim import _LRScheduler - - -def warmup_cosine(x, warmup=0.002): - if x < warmup: - return x / warmup - return 0.5 * (1.0 + torch.cos(math.pi * x)) - - -def warmup_constant(x, warmup=0.002): - if x < warmup: - return x / warmup - return 1.0 - - -def warmup_linear(x, warmup=0.002): - if x < warmup: - return x / warmup - return max((x - 1.0) / (warmup - 1.0), 0.0) - - -def warmup_poly(x, warmup=0.002, degree=0.5): - if x < warmup: - return x / warmup - return (1.0 - x) ** degree - - -SCHEDULES = { - "warmup_cosine": warmup_cosine, - "warmup_constant": warmup_constant, - "warmup_linear": warmup_linear, - "warmup_poly": warmup_poly, -} - - -def get_lr(args, training_steps, schedule="warmup_poly"): - if args.max_steps == -1: - return args.learning_rate - - schedule_fct = SCHEDULES[schedule] - return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion) - - -def map_optimizer_attributes(name): - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) - if no_decay: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - else: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - - -class WrapLRScheduler(_LRScheduler): - def __init__(self, get_lr_this_step): - super().__init__() - self.get_lr_this_step = get_lr_this_step - - def get_lr(self, train_step_info): - return [self.get_lr_this_step(train_step_info.optimization_step)] - - -def run_test( - model, - model_desc, - device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - batch_args_option, - dataset_len, - epochs, - use_new_api, -): - dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device) - - if use_new_api: - assert use_internal_loss_scaler, "new api should always use internal loss scaler" - - new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step) - - new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "device": {"id": device}, - "mixed_precision": {"enabled": fp16, "loss_scaler": new_api_loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": True}, - "distributed": {"allreduce_post_accumulation": True}, - "lr_scheduler": new_api_lr_scheduler, - } - ) - - param_optimizer = list(model.named_parameters()) - params = [ - { - "params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - { - "params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - ] - - vocab_size = 99 - new_model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - ), - ( - "next_sentence_label", - [ - "batch", - ], - ), - ], - "outputs": [ - ( - "loss", - [ - 1, - ], - True, - ), - ("prediction_scores", ["batch", "max_seq_len_in_batch", vocab_size]), - ("seq_relationship_scores", ["batch", 2]), - ], - } - - optim_config = optim.LambConfig(params=params, lr=2e-5) - model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options) - print("running with new frontend API") - else: - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes=map_optimizer_attributes, - learning_rate_description=IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device=device, - _enable_internal_postprocess=True, - gradient_accumulation_steps=gradient_accumulation_steps, - # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6 - world_rank=args.local_rank, - world_size=args.world_size, - use_mixed_precision=fp16, - allreduce_post_accumulation=allreduce_post_accumulation, - get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None, - loss_scaler=loss_scaler if use_internal_loss_scaler else None, - _opset_version=14, - _use_deterministic_compute=True, - ) - print("running with old frontend API") - - # training loop - eval_batch = None - if not use_new_api: - model.train() - for _epoch in range(epochs): - for step, batch in enumerate(dataloader): - if eval_batch is None: - eval_batch = batch - - if not use_internal_get_lr_this_step: - lr = get_lr_this_step(step) - learning_rate = torch.tensor([lr]) - - if not use_internal_loss_scaler and fp16: - loss_scale = torch.tensor([loss_scaler.loss_scale_]) - - if batch_args_option == BatchArgsOption.List: - if not use_internal_get_lr_this_step: - batch = [*batch, learning_rate] # noqa: PLW2901 - if not use_internal_loss_scaler and fp16: - batch = [*batch, loss_scale] # noqa: PLW2901 - outputs = model.train_step(*batch) - elif batch_args_option == BatchArgsOption.Dict: - args, kwargs = split_batch(batch, model_desc.inputs_, 0) - if not use_internal_get_lr_this_step: - kwargs["Learning_Rate"] = learning_rate - if not use_internal_loss_scaler and fp16: - kwargs[model.loss_scale_input_name] = loss_scale - outputs = model.train_step(*args, **kwargs) - else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs - args, kwargs = split_batch(batch, model_desc.inputs_, args_count) - if not use_internal_get_lr_this_step: - kwargs["Learning_Rate"] = learning_rate - if not use_internal_loss_scaler and fp16: - kwargs[model.loss_scale_input_name] = loss_scale - outputs = model.train_step(*args, **kwargs) - - # eval - if batch_args_option == BatchArgsOption.List: - outputs = model.eval_step(*batch) - elif batch_args_option == BatchArgsOption.Dict: - args, kwargs = split_batch(batch, model_desc.inputs_, 0) - outputs = model.eval_step(*args, **kwargs) - else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs - args, kwargs = split_batch(batch, model_desc.inputs_, args_count) - outputs = model.eval_step(*args, **kwargs) - - return (output.cpu().numpy() for output in outputs) diff --git a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py b/orttraining/orttraining/test/python/orttraining_transformer_trainer.py deleted file mode 100644 index bce726871bacf..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py +++ /dev/null @@ -1,357 +0,0 @@ -# adapted from Trainer.py of huggingface transformers - -import json -import logging -import os -import random -from typing import Callable, Dict, List, NamedTuple, Optional - -import numpy as np -import torch -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Dataset -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import SequentialSampler -from tqdm import tqdm, trange -from transformers.data.data_collator import DefaultDataCollator -from transformers.modeling_utils import PreTrainedModel -from transformers.training_args import TrainingArguments - -import onnxruntime -from onnxruntime.training import amp, optim, orttrainer - -try: - from torch.utils.tensorboard import SummaryWriter - - _has_tensorboard = True -except ImportError: - try: - from tensorboardX import SummaryWriter # noqa: F401 - - _has_tensorboard = True - except ImportError: - _has_tensorboard = False - - -def is_tensorboard_available(): - return _has_tensorboard - - -logger = logging.getLogger(__name__) - - -def set_seed(seed: int): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - -class EvalPrediction(NamedTuple): - predictions: np.ndarray - label_ids: np.ndarray - - -class PredictionOutput(NamedTuple): - predictions: np.ndarray - label_ids: Optional[np.ndarray] - metrics: Optional[Dict[str, float]] - - -class TrainOutput(NamedTuple): - global_step: int - training_loss: float - - -def get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps, base_lr): - def lr_lambda_linear(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) - - def lambda_lr_get_lr(current_global_step): - # LambdaLR increment self.last_epoch at evert sept() - return base_lr * lr_lambda_linear(current_global_step) - - return lambda_lr_get_lr - - -class ORTTransformerTrainer: - """ """ - - model: PreTrainedModel - args: TrainingArguments - train_dataset: Dataset - eval_dataset: Dataset - compute_metrics: Callable[[EvalPrediction], Dict] - - def __init__( - self, - model: PreTrainedModel, - model_desc: dict, - args: TrainingArguments, - train_dataset: Dataset, - eval_dataset: Dataset, - compute_metrics: Callable[[EvalPrediction], Dict], - world_size: Optional[int] = 1, - ): - """ """ - - self.model = model - self.model_desc = model_desc - self.args = args - self.world_size = world_size - self.data_collator = DefaultDataCollator() - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self.compute_metrics = compute_metrics - set_seed(self.args.seed) - # Create output directory if needed - if self.args.local_rank in [-1, 0]: - os.makedirs(self.args.output_dir, exist_ok=True) - - def get_train_dataloader(self) -> DataLoader: - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - train_sampler = ( - SequentialSampler(self.train_dataset) - if self.args.local_rank == -1 - else DistributedSampler(self.train_dataset) - ) - return DataLoader( - self.train_dataset, - batch_size=self.args.train_batch_size, - sampler=train_sampler, - collate_fn=self.data_collator.collate_batch, - ) - - def get_eval_dataloader(self) -> DataLoader: - return DataLoader( - self.eval_dataset, - batch_size=self.args.eval_batch_size, - shuffle=False, - collate_fn=self.data_collator.collate_batch, - ) - - def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: - # We use the same batch_size as for eval. - return DataLoader( - test_dataset, - batch_size=self.args.eval_batch_size, - shuffle=False, - collate_fn=self.data_collator.collate_batch, - ) - - def train(self): - """ - Main training entry point. - """ - train_dataloader = self.get_train_dataloader() - - if self.args.max_steps > 0: - t_total = self.args.max_steps - num_train_epochs = ( - self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 - ) - else: - t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) - num_train_epochs = self.args.num_train_epochs - - lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps / float(t_total)) - - loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None - device = self.args.device.type - - device = f"{device}:{self.args.device.index}" if self.args.device.index else f"{device}:0" - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": self.args.gradient_accumulation_steps}, - "device": {"id": device}, - "mixed_precision": {"enabled": self.args.fp16, "loss_scaler": loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": False}, - "distributed": { - # we are running single node multi gpu test. thus world_rank = local_rank - # and world_size = self.args.n_gpu - "world_rank": max(0, self.args.local_rank), - "world_size": int(self.world_size), - "local_rank": max(0, self.args.local_rank), - "allreduce_post_accumulation": True, - }, - "lr_scheduler": lr_scheduler, - } - ) - - param_optimizer = list(self.model.named_parameters()) - params = [ - { - "params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], - "weight_decay_mode": 1, - }, - { - "params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], - "weight_decay_mode": 1, - }, - ] - - optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) - self.model = orttrainer.ORTTrainer(self.model, self.model_desc, optim_config, options=options) - - # Train! - logger.info("***** Running training *****") - logger.info(" Num examples = %d", len(train_dataloader.dataset)) - logger.info(" Num Epochs = %d", num_train_epochs) - logger.info(" Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size) - logger.info( - " Total train batch size (w. parallel, distributed & accumulation) = %d", - self.args.train_batch_size - * self.args.gradient_accumulation_steps - * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), - ) - logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) - logger.info(" Total optimization steps = %d", t_total) - - global_step = 0 - epochs_trained = 0 - steps_trained_in_current_epoch = 0 - - tr_loss = 0.0 - logging_loss = 0.0 - train_iterator = trange( - epochs_trained, - int(num_train_epochs), - desc="Epoch", - disable=self.args.local_rank not in [-1, 0], - ) - - for _epoch in train_iterator: - epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) - for step, inputs in enumerate(epoch_iterator): - # Skip past any already trained steps if resuming training - if steps_trained_in_current_epoch > 0: - steps_trained_in_current_epoch -= 1 - continue - - tr_loss += self._training_step(self.model, inputs) - - if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( - len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator) - ): - global_step += 1 - - if self.args.local_rank in [-1, 0]: - if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or ( - global_step == 1 and self.args.logging_first_step - ): - logs = {} - if self.args.evaluate_during_training: - results = self.evaluate() - for key, value in results.items(): - eval_key = f"eval_{key}" - logs[eval_key] = value - - loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps - - logs["loss"] = loss_scalar - logging_loss = tr_loss - - epoch_iterator.write(json.dumps({**logs, **{"step": global_step}})) - - if self.args.max_steps > 0 and global_step > self.args.max_steps: - epoch_iterator.close() - break - if self.args.max_steps > 0 and global_step > self.args.max_steps: - train_iterator.close() - break - - logger.info("\n\nTraining completed. \n\n") - return TrainOutput(global_step, tr_loss / global_step) - - def _training_step(self, model, inputs: Dict[str, torch.Tensor]) -> float: - for k, v in inputs.items(): - inputs[k] = v.to(self.args.device) - - outputs = model.train_step(**inputs) - loss = outputs[0] # model outputs are always tuple in transformers (see doc) - - return loss.item() - - def save_model(self, output_dir: Optional[str] = None): - output_dir = output_dir if output_dir is not None else self.args.output_dir - os.makedirs(output_dir, exist_ok=True) - self.model.save_as_onnx(os.path.join(output_dir, "transformer.onnx")) - - def evaluate(self) -> Dict[str, float]: - """ - Run evaluation and return metrics. - - Returns: - A dict containing: - - the eval loss - - the potential metrics computed from the predictions - """ - eval_dataloader = self.get_eval_dataloader() - - output = self._prediction_loop(eval_dataloader, description="Evaluation") - return output.metrics - - def predict(self, test_dataset: Dataset) -> PredictionOutput: - """ - Run prediction and return predictions and potential metrics. - - Depending on the dataset and your use case, your test dataset may contain labels. - In that case, this method will also return metrics, like in evaluate(). - """ - test_dataloader = self.get_test_dataloader(test_dataset) - return self._prediction_loop(test_dataloader, description="Prediction") - - def _prediction_loop(self, dataloader: DataLoader, description: str) -> PredictionOutput: - """ - Prediction/evaluation loop, shared by `evaluate()` and `predict()`. - - Works both with or without labels. - """ - - logger.info("***** Running %s *****", description) - logger.info(" Num examples = %d", len(dataloader.dataset)) - logger.info(" Batch size = %d", dataloader.batch_size) - eval_losses: List[float] = [] - preds: np.ndarray = None - label_ids: np.ndarray = None - - for inputs in tqdm(dataloader, desc=description): - has_labels = any(inputs.get(k) is not None for k in ["labels", "masked_lm_labels"]) - - for k, v in inputs.items(): - inputs[k] = v.to(self.args.device) - - with torch.no_grad(): - outputs = self.model.eval_step(**inputs) - - if has_labels: - step_eval_loss, logits = outputs[:2] - eval_losses += [step_eval_loss.mean().item()] - else: - logits = outputs[0] - - if preds is None: - preds = logits.detach().cpu().numpy() - else: - preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) - if inputs.get("labels") is not None: - if label_ids is None: - label_ids = inputs["labels"].detach().cpu().numpy() - else: - label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) - - if self.compute_metrics is not None and preds is not None and label_ids is not None: - metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) - else: - metrics = {} - if len(eval_losses) > 0: - metrics["loss"] = np.mean(eval_losses) - - return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) diff --git a/orttraining/orttraining/test/python/utils_multiple_choice.py b/orttraining/orttraining/test/python/utils_multiple_choice.py deleted file mode 100644 index e0febaf2d6334..0000000000000 --- a/orttraining/orttraining/test/python/utils_multiple_choice.py +++ /dev/null @@ -1,269 +0,0 @@ -# adapted from run_multiple_choice.py of huggingface transformers -# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/utils_multiple_choice.py - -import csv -import glob # noqa: F401 -import json # noqa: F401 -import logging -import os -from dataclasses import dataclass -from enum import Enum -from typing import List, Optional - -import torch -import tqdm -from filelock import FileLock -from torch.utils.data.dataset import Dataset -from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available # noqa: F401 - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class InputExample: - """ - A single training/test example for multiple choice - - Args: - example_id: Unique id for the example. - question: string. The untokenized text of the second sequence (question). - contexts: list of str. The untokenized text of the first sequence (context of corresponding question). - endings: list of str. multiple choice's options. Its length must be equal to contexts' length. - label: (Optional) string. The label of the example. This should be - specified for train and dev examples, but not for test examples. - """ - - example_id: str - question: str - contexts: List[str] - endings: List[str] - label: Optional[str] - - -@dataclass(frozen=True) -class InputFeatures: - """ - A single set of features of data. - Property names are the same names as the corresponding inputs to a model. - """ - - example_id: str - input_ids: List[List[int]] - attention_mask: Optional[List[List[int]]] - token_type_ids: Optional[List[List[int]]] - label: Optional[int] - - -class Split(Enum): - train = "train" - dev = "dev" - test = "test" - - -class DataProcessor: - """Base class for data converters for multiple choice data sets.""" - - def get_train_examples(self, data_dir): - """Gets a collection of `InputExample`s for the train set.""" - raise NotImplementedError() - - def get_dev_examples(self, data_dir): - """Gets a collection of `InputExample`s for the dev set.""" - raise NotImplementedError() - - def get_test_examples(self, data_dir): - """Gets a collection of `InputExample`s for the test set.""" - raise NotImplementedError() - - def get_labels(self): - """Gets the list of labels for this data set.""" - raise NotImplementedError() - - -class MultipleChoiceDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach - soon. - """ - - features: List[InputFeatures] - - def __init__( - self, - data_dir: str, - tokenizer: PreTrainedTokenizer, - task: str, - processor: DataProcessor, - max_seq_length: Optional[int] = None, - overwrite_cache=False, - mode: Split = Split.train, - ): - cached_features_file = os.path.join( - data_dir, - "cached_{}_{}_{}_{}".format( - mode.value, - tokenizer.__class__.__name__, - str(max_seq_length), - task, - ), - ) - - # Make sure only the first process in distributed training processes the dataset, - # and the others will use the cache. - lock_path = cached_features_file + ".lock" - with FileLock(lock_path): - if os.path.exists(cached_features_file) and not overwrite_cache: - logger.info(f"Loading features from cached file {cached_features_file}") - self.features = torch.load(cached_features_file) - else: - logger.info(f"Creating features from dataset file at {data_dir}") - label_list = processor.get_labels() - if mode == Split.dev: - examples = processor.get_dev_examples(data_dir) - elif mode == Split.test: - examples = processor.get_test_examples(data_dir) - else: - examples = processor.get_train_examples(data_dir) - logger.info("Training examples: %s", len(examples)) - # TODO clean up all this to leverage built-in features of tokenizers - self.features = convert_examples_to_features( - examples, - label_list, - max_seq_length, - tokenizer, - pad_on_left=bool(tokenizer.padding_side == "left"), - pad_token=tokenizer.pad_token_id, - pad_token_segment_id=tokenizer.pad_token_type_id, - ) - logger.info("Saving features into cached file %s", cached_features_file) - torch.save(self.features, cached_features_file) - - def __len__(self): - return len(self.features) - - def __getitem__(self, i) -> InputFeatures: - return self.features[i] - - -class SwagProcessor(DataProcessor): - """Processor for the SWAG data set.""" - - def get_train_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} train") - return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train") - - def get_dev_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} dev") - return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev") - - def get_test_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} dev") - raise ValueError( - "For swag testing, the input file does not contain a label column. It can not be tested in current code" - "setting!" - ) - return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test") - - def get_labels(self): - """See base class.""" - return ["0", "1", "2", "3"] - - def _read_csv(self, input_file): - with open(input_file, encoding="utf-8") as f: - return list(csv.reader(f)) - - def _create_examples(self, lines: List[List[str]], type: str): - """Creates examples for the training and dev sets.""" - if type == "train" and lines[0][-1] != "label": - raise ValueError("For training, the input file must contain a label column.") - - examples = [ - InputExample( - example_id=line[2], - question=line[5], # in the swag dataset, the - # common beginning of each - # choice is stored in "sent2". - contexts=[line[4], line[4], line[4], line[4]], - endings=[line[7], line[8], line[9], line[10]], - label=line[11], - ) - for line in lines[1:] # we skip the line with the column names - ] - - return examples - - -def convert_examples_to_features( - examples: List[InputExample], - label_list: List[str], - max_length: int, - tokenizer: PreTrainedTokenizer, - pad_token_segment_id=0, - pad_on_left=False, - pad_token=0, - mask_padding_with_zero=True, -) -> List[InputFeatures]: - """ - Loads a data file into a list of `InputFeatures` - """ - - label_map = {label: i for i, label in enumerate(label_list)} - - features = [] - for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"): - if ex_index % 10000 == 0: - logger.info("Writing example %d of %d" % (ex_index, len(examples))) - choices_inputs = [] - for _ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)): - text_a = context - if example.question.find("_") != -1: - # this is for cloze question - text_b = example.question.replace("_", ending) - else: - text_b = example.question + " " + ending - - inputs = tokenizer.encode_plus( - text_a, - text_b, - add_special_tokens=True, - max_length=max_length, - pad_to_max_length=True, - return_overflowing_tokens=True, - ) - if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0: - logger.info( - "Attention! you are cropping tokens (swag task is ok). " - "If you are training ARC and RACE and you are poping question + options," - "you need to try to use a bigger max seq length!" - ) - - choices_inputs.append(inputs) - - label = label_map[example.label] - - input_ids = [x["input_ids"] for x in choices_inputs] - attention_mask = ( - [x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None - ) - token_type_ids = ( - [x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None - ) - - features.append( - InputFeatures( - example_id=example.example_id, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - label=label, - ) - ) - - for f in features[:2]: - logger.info("*** Example ***") - logger.info("feature: %s" % f) - - return features diff --git a/orttraining/pytorch_frontend_examples/mnist_training.py b/orttraining/pytorch_frontend_examples/mnist_training.py deleted file mode 100644 index dc9b3f654400c..0000000000000 --- a/orttraining/pytorch_frontend_examples/mnist_training.py +++ /dev/null @@ -1,200 +0,0 @@ -## This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py -## with modification to do training using onnxruntime as backend on cuda device. -## A private PyTorch build from https://aiinfra.visualstudio.com/Lotus/_git/pytorch (ORTTraining branch) is needed to run the demo. - -## Model testing is not complete. - -import argparse -import os - -import numpy as np # noqa: F401 -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim # noqa: F401 -from mpi4py import MPI -from torchvision import datasets, transforms - -from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer - -try: # noqa: SIM105 - from onnxruntime.capi._pybind_state import set_cuda_device_id -except ImportError: - pass - - -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return out - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -def train_with_trainer(args, trainer, device, train_loader, epoch): - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - learning_rate = torch.tensor([args.lr]) - loss = trainer.train_step(data, target, learning_rate) - - # Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss. - if batch_idx % args.log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss[0], - ) - ) - - -# TODO: comple this once ORT training can do evaluation. -def test_with_trainer(args, trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = F.log_softmax(trainer.eval_step(data, fetches=["probability"]), dim=1) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def mnist_model_description(): - input_desc = IODescription("input1", ["batch", 784], torch.float32) - label_desc = IODescription( - "label", - [ - "batch", - ], - torch.int64, - num_classes=10, - ) - loss_desc = IODescription("loss", [], torch.float32) - probability_desc = IODescription("probability", ["batch", 10], torch.float32) - return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc]) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - - args = parser.parse_args() - use_cuda = not args.no_cuda and torch.cuda.is_available() - - torch.manual_seed(args.seed) - - kwargs = {"num_workers": 0, "pin_memory": True} - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "../data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - **kwargs, - ) - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "../data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - **kwargs, - ) - - comm = MPI.COMM_WORLD - args.local_rank = ( - int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) if ("OMPI_COMM_WORLD_LOCAL_RANK" in os.environ) else 0 - ) - args.world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) if ("OMPI_COMM_WORLD_RANK" in os.environ) else 0 - args.world_size = comm.Get_size() - if use_cuda: - torch.cuda.set_device(args.local_rank) - device = torch.device("cuda", args.local_rank) - args.n_gpu = 1 - set_cuda_device_id(args.local_rank) - else: - device = torch.device("cpu") - - input_size = 784 - hidden_size = 500 - num_classes = 10 - model = NeuralNet(input_size, hidden_size, num_classes) - - model_desc = mnist_model_description() - # use log_interval as gradient accumulate steps - trainer = ORTTrainer( - model, - my_loss, - model_desc, - "SGDOptimizer", - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - 1, - args.world_rank, - args.world_size, - use_mixed_precision=False, - allreduce_post_accumulation=True, - ) - print("\nBuild ort model done.") - - for epoch in range(1, args.epochs + 1): - train_with_trainer(args, trainer, device, train_loader, epoch) - test_with_trainer(args, trainer, device, test_loader) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/mnist/mnist_original.onnx b/samples/python/training/orttrainer/mnist/mnist_original.onnx deleted file mode 100644 index 15931affb5ccf..0000000000000 Binary files a/samples/python/training/orttrainer/mnist/mnist_original.onnx and /dev/null differ diff --git a/samples/python/training/orttrainer/mnist/ort_mnist.py b/samples/python/training/orttrainer/mnist/ort_mnist.py deleted file mode 100644 index 8f8ccf373ccf6..0000000000000 --- a/samples/python/training/orttrainer/mnist/ort_mnist.py +++ /dev/null @@ -1,174 +0,0 @@ -# This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py -# with modification to do training using onnxruntime as backend on cuda device. - -import argparse -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision import datasets, transforms - -import onnxruntime -from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim - - -# Pytorch model -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, input1): - out = self.fc1(input1) - out = self.relu(out) - out = self.fc2(out) - return out - - -# ONNX Runtime training -def mnist_model_description(): - return { - "inputs": [("input1", ["batch", 784]), ("label", ["batch"])], - "outputs": [("loss", [], True), ("probability", ["batch", 10])], - } - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -# Helpers -def train(log_interval, trainer, device, train_loader, epoch, train_steps): - for batch_idx, (data, target) in enumerate(train_loader): - if batch_idx == train_steps: - break - - # Fetch data - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - # Train step - loss, prob = trainer.train_step(data, target) - - # Stats - if batch_idx % log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, batch_idx * len(data), len(train_loader.dataset), 100.0 * batch_idx / len(train_loader), loss - ) - ) - - -def test(trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - # Using fetches around without eval_step to not pass 'target' as input - trainer._train_step_info.fetches = ["probability"] - output = F.log_softmax(trainer.eval_step(data), dim=1) - trainer._train_step_info.fetches = [] - - # Stats - test_loss += F.nll_loss(output, target, reduction="sum").item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="ONNX Runtime MNIST Example") - parser.add_argument( - "--train-steps", - type=int, - default=-1, - metavar="N", - help="number of steps to train. Set -1 to run through whole dataset (default: -1)", - ) - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model state") - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - onnxruntime.set_seed(args.seed) - - # Data loader - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - ) - - if args.test_batch_size > 0: - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - ) - - # Modeling - model = NeuralNet(784, 500, 10) - model_desc = mnist_model_description() - optim_config = optim.SGDConfig(lr=args.lr) - opts = {"device": {"id": device}} - opts = ORTTrainerOptions(opts) - - trainer = ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - - # Train loop - for epoch in range(1, args.epochs + 1): - train(args.log_interval, trainer, device, train_loader, epoch, args.train_steps) - if args.test_batch_size > 0: - test(trainer, device, test_loader) - - # Save model - if args.save_path: - torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt")) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/mnist/pytorch_mnist.py b/samples/python/training/orttrainer/mnist/pytorch_mnist.py deleted file mode 100644 index 2e451d85f62e8..0000000000000 --- a/samples/python/training/orttrainer/mnist/pytorch_mnist.py +++ /dev/null @@ -1,157 +0,0 @@ -import argparse -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms - - -# Pytorch model -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, input1): - out = self.fc1(input1) - out = self.relu(out) - out = self.fc2(out) - return out - - -def my_loss(x, target, is_train=True): - if is_train: - return F.nll_loss(F.log_softmax(x, dim=1), target) - else: - return F.nll_loss(F.log_softmax(x, dim=1), target, reduction="sum") - - -# Helpers -def train(args, model, device, train_loader, optimizer, epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - if batch_idx == args.train_steps: - break - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - optimizer.zero_grad() - output = model(data) - loss = my_loss(output, target) - loss.backward() - optimizer.step() - if batch_idx % args.log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - - -def test(model, device, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = model(data) - # Stats - test_loss += my_loss(output, target, False).item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--train-steps", - type=int, - default=-1, - metavar="N", - help="number of steps to train. Set -1 to run through whole dataset (default: -1)", - ) - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model") - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - - # Data loader - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - ) - - if args.test_batch_size > 0: - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - ) - - # Modeling - model = NeuralNet(784, 500, 10).to(device) - optimizer = optim.SGD(model.parameters(), lr=args.lr) - - # Train loop - for epoch in range(1, args.epochs + 1): - train(args, model, device, train_loader, optimizer, epoch) - if args.test_batch_size > 0: - test(model, device, test_loader) - - # Save model - if args.save_path: - torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt")) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/pytorch_transformer/README.md b/samples/python/training/orttrainer/pytorch_transformer/README.md deleted file mode 100644 index cda8cba6ca0ad..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# TransformerModel example - -This example was adapted from Pytorch's [Sequence-to-Sequence Modeling with nn.Transformer and TorchText](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) tutorial - -## Requirements - -* PyTorch 1.6+ -* TorchText 0.6+ -* ONNX Runtime 1.5+ - -## Running PyTorch version - -```bash -python pt_train.py -``` - -## Running ONNX Runtime version - -```bash -python ort_train.py -``` - -## Optional arguments - -| Argument | Description | Default | -| :---------------- | :-----------------------------------------------------: | --------: | -| --batch-size | input batch size for training | 20 | -| --test-batch-size | input batch size for testing | 20 | -| --epochs | number of epochs to train | 2 | -| --lr | learning rate | 0.001 | -| --no-cuda | disables CUDA training | False | -| --seed | random seed | 1 | -| --log-interval | how many batches to wait before logging training status | 200 | diff --git a/samples/python/training/orttrainer/pytorch_transformer/ort_train.py b/samples/python/training/orttrainer/pytorch_transformer/ort_train.py deleted file mode 100644 index 551e878cc9035..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/ort_train.py +++ /dev/null @@ -1,89 +0,0 @@ -import argparse - -import torch -from ort_utils import my_loss, transformer_model_description_dynamic_axes -from pt_model import TransformerModel -from utils import get_batch, prepare_data - -import onnxruntime - - -def train(trainer, data_source, device, epoch, args, bptt=35): - total_loss = 0.0 - for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): - data, targets = get_batch(data_source, i) - - loss, pred = trainer.train_step(data, targets) - total_loss += loss.item() - if batch % args.log_interval == 0 and batch > 0: - cur_loss = total_loss / args.log_interval - print( - "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format( - epoch, batch, len(data_source) // bptt, cur_loss - ) - ) - total_loss = 0 - - -def evaluate(trainer, data_source, bptt=35): - total_loss = 0.0 - with torch.no_grad(): - for i in range(0, data_source.size(0) - 1, bptt): - data, targets = get_batch(data_source, i) - loss, pred = trainer.eval_step(data, targets) - total_loss += len(data) * loss.item() - return total_loss / (len(data_source) - 1) - - -if __name__ == "__main__": - # Training settings - parser = argparse.ArgumentParser(description="PyTorch TransformerModel example") - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)" - ) - parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)") - parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=200, - metavar="N", - help="how many batches to wait before logging training status (default: 200)", - ) - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - onnxruntime.set_seed(args.seed) - - # Model - optim_config = onnxruntime.training.optim.SGDConfig(lr=args.lr) - model_desc = transformer_model_description_dynamic_axes() - model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - - # Preparing data - train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size) - trainer = onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss) - - # Train - for epoch in range(1, args.epochs + 1): - train(trainer, train_data, device, epoch, args) - val_loss = evaluate(trainer, val_data) - print("-" * 89) - print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ") - print("-" * 89) - - # Evaluate - test_loss = evaluate(trainer, test_data) - print("=" * 89) - print(f"| End of training | test loss {test_loss:5.2f}") - print("=" * 89) diff --git a/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py b/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py deleted file mode 100644 index 73992f5596f5f..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription -from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription - - -def my_loss(x, target): - x = x.view(-1, 28785) - return torch.nn.CrossEntropyLoss()(x, target) - - -def transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - model_desc = { - "inputs": [("input1", [bptt, batch_size]), ("label", [bptt * batch_size])], - "outputs": [("loss", [], True), ("predictions", [bptt, batch_size, ntokens])], - } - return model_desc - - -def transformer_model_description_dynamic_axes(ntokens=28785): - model_desc = { - "inputs": [("input1", ["bptt", "batch_size"]), ("label", ["bptt_x_batch_size"])], - "outputs": [("loss", [], True), ("predictions", ["bptt", "batch_size", ntokens])], - } - return model_desc - - -def legacy_transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - input_desc = Legacy_IODescription("input1", [bptt, batch_size]) - label_desc = Legacy_IODescription("label", [bptt * batch_size]) - loss_desc = Legacy_IODescription("loss", []) - predictions_desc = Legacy_IODescription("predictions", [bptt, batch_size, ntokens]) - return ( - Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), - Legacy_IODescription("__learning_rate", [1]), - ) - - -def legacy_transformer_model_description_dynamic_axes(ntokens=28785): - input_desc = Legacy_IODescription("input1", ["bptt", "batch_size"]) - label_desc = Legacy_IODescription("label", ["bptt_x_batch_size"]) - loss_desc = Legacy_IODescription("loss", []) - predictions_desc = Legacy_IODescription("predictions", ["bptt", "batch_size", ntokens]) - return ( - Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), - Legacy_IODescription("__learning_rate", [1]), - ) diff --git a/samples/python/training/orttrainer/pytorch_transformer/pt_model.py b/samples/python/training/orttrainer/pytorch_transformer/pt_model.py deleted file mode 100644 index 4f2e03192c6cf..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/pt_model.py +++ /dev/null @@ -1,62 +0,0 @@ -import math - -import torch -import torch.nn as nn - - -class TransformerModel(nn.Module): - def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): - super().__init__() - from torch.nn import TransformerEncoder, TransformerEncoderLayer - - self.model_type = "Transformer" - self.input1_mask = None - self.pos_encoder = PositionalEncoding(ninp, dropout) - encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) - self.encoder = nn.Embedding(ntoken, ninp) - self.ninp = ninp - self.decoder = nn.Linear(ninp, ntoken) - - self.init_weights() - - def _generate_square_subsequent_mask(self, sz): - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0) - return mask - - def init_weights(self): - initrange = 0.1 - self.encoder.weight.data.uniform_(-initrange, initrange) - self.decoder.bias.data.zero_() - self.decoder.weight.data.uniform_(-initrange, initrange) - - def forward(self, input1): - if self.input1_mask is None or self.input1_mask.size(0) != input1.size(0): - device = input1.device - mask = self._generate_square_subsequent_mask(input1.size(0)).to(device) - self.input1_mask = mask - - input1 = self.encoder(input1) * math.sqrt(self.ninp) - input1 = self.pos_encoder(input1) - output = self.transformer_encoder(input1, self.input1_mask) - output = self.decoder(output) - return output - - -class PositionalEncoding(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer("pe", pe) - - def forward(self, x): - x = x + self.pe[: x.size(0), :] - return self.dropout(x) diff --git a/samples/python/training/orttrainer/pytorch_transformer/pt_train.py b/samples/python/training/orttrainer/pytorch_transformer/pt_train.py deleted file mode 100644 index a197fb50357e9..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/pt_train.py +++ /dev/null @@ -1,94 +0,0 @@ -import argparse - -import torch -import torch.nn as nn -from pt_model import TransformerModel -from utils import get_batch, prepare_data - - -def train(model, data_source, device, epoch, args, bptt=35): - total_loss = 0.0 - model.train() - for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): - data, targets = get_batch(data_source, i) - - optimizer.zero_grad() - output = model(data) - loss = criterion(output.view(-1, 28785), targets) - loss.backward() - optimizer.step() - - total_loss += loss.item() - if batch % args.log_interval == 0 and batch > 0: - cur_loss = total_loss / args.log_interval - print( - "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format( - epoch, batch, len(data_source) // bptt, cur_loss - ) - ) - total_loss = 0 - - -def evaluate(model, data_source, criterion, bptt=35): - total_loss = 0.0 - model.eval() - with torch.no_grad(): - for i in range(0, data_source.size(0) - 1, bptt): - data, targets = get_batch(data_source, i) - output = model(data) - output_flat = output.view(-1, 28785) - total_loss += len(data) * criterion(output_flat, targets).item() - return total_loss / (len(data_source) - 1) - - -if __name__ == "__main__": - # Training settings - parser = argparse.ArgumentParser(description="PyTorch TransformerModel example") - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)" - ) - parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)") - parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=200, - metavar="N", - help="how many batches to wait before logging training status (default: 200)", - ) - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - - # Model - criterion = nn.CrossEntropyLoss() - lr = 0.001 - model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - optimizer = torch.optim.SGD(model.parameters(), lr=lr) - - # Preparing data - train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size) - - # Train - for epoch in range(1, args.epochs + 1): - train(model, train_data, device, epoch, args) - val_loss = evaluate(model, val_data, criterion) - print("-" * 89) - print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ") - print("-" * 89) - - # Evaluate - test_loss = evaluate(model, test_data, criterion) - print("=" * 89) - print(f"| End of training | test loss {test_loss:5.2f}") - print("=" * 89) diff --git a/samples/python/training/orttrainer/pytorch_transformer/utils.py b/samples/python/training/orttrainer/pytorch_transformer/utils.py deleted file mode 100644 index 3be8b6cf3f420..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/utils.py +++ /dev/null @@ -1,59 +0,0 @@ -import os - -import torch -from torchtext.data.utils import get_tokenizer -from torchtext.utils import download_from_url, extract_archive -from torchtext.vocab import build_vocab_from_iterator - - -def batchify(data, bsz, device): - # Divide the dataset into bsz parts. - nbatch = data.size(0) // bsz - # Trim off any extra elements that wouldn't cleanly fit (remainders). - data = data.narrow(0, 0, nbatch * bsz) - # Evenly divide the data across the bsz batches. - data = data.view(bsz, -1).t().contiguous() - return data.to(device) - - -def get_batch(source, i, bptt=35): - seq_len = min(bptt, len(source) - 1 - i) - data = source[i : i + seq_len] - target = source[i + 1 : i + 1 + seq_len].view(-1) - return data, target - - -def prepare_data(device="cpu", train_batch_size=20, eval_batch_size=20, data_dir=None): - url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip" - - download_path = ".data_wikitext_2_v1" - extract_path = None - if data_dir: - download_path = os.path.join(data_dir, "download") - os.makedirs(download_path, exist_ok=True) - download_path = os.path.join(download_path, "wikitext-2-v1.zip") - - extract_path = os.path.join(data_dir, "extracted") - os.makedirs(extract_path, exist_ok=True) - - test_filepath, valid_filepath, train_filepath = extract_archive( - download_from_url(url, root=download_path), to_path=extract_path - ) - tokenizer = get_tokenizer("basic_english") - vocab = build_vocab_from_iterator(map(tokenizer, iter(open(train_filepath, encoding="utf8")))) # noqa: SIM115 - - def data_process(raw_text_iter): - data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter] - return torch.cat(tuple(filter(lambda t: t.numel() > 0, data))) - - train_data = data_process(iter(open(train_filepath, encoding="utf8"))) # noqa: SIM115 - val_data = data_process(iter(open(valid_filepath, encoding="utf8"))) # noqa: SIM115 - test_data = data_process(iter(open(test_filepath, encoding="utf8"))) # noqa: SIM115 - - device = torch.device(device) - - train_data = batchify(train_data, train_batch_size, device) - val_data = batchify(val_data, eval_batch_size, device) - test_data = batchify(test_data, eval_batch_size, device) - - return train_data, val_data, test_data diff --git a/setup.py b/setup.py index 9eca9845c9e8b..798c8c4b2895b 100644 --- a/setup.py +++ b/setup.py @@ -196,7 +196,7 @@ def run(self): "libcublasLt.so.11", "libcublasLt.so.12", "libcudart.so.11.0", - "libcudart.so.12.0", + "libcudart.so.12", "libcudnn.so.8", "libcufft.so.10", "libcufft.so.11", @@ -219,6 +219,10 @@ def run(self): "librocm_smi64.so.5", "libroctracer64.so.4", "libtinfo.so.6", + "libmigraphx_c.so.3", + "libmigraphx.so.2", + "libmigraphx_onnx.so.2", + "libmigraphx_tf.so.2", ] tensorrt_dependencies = ["libnvinfer.so.8", "libnvinfer_plugin.so.8", "libnvonnxparser.so.8"] @@ -394,7 +398,6 @@ def finalize_options(self): "onnxruntime", "onnxruntime.backend", "onnxruntime.capi", - "onnxruntime.capi.training", "onnxruntime.datasets", "onnxruntime.tools", "onnxruntime.tools.mobile_helpers", diff --git a/tools/android_custom_build/Dockerfile b/tools/android_custom_build/Dockerfile index 66b6a36e5a8c0..754a6633b0c62 100644 --- a/tools/android_custom_build/Dockerfile +++ b/tools/android_custom_build/Dockerfile @@ -55,7 +55,7 @@ WORKDIR /workspace # install Android SDK and tools ENV ANDROID_HOME=~/android-sdk -ENV NDK_VERSION=26.0.10792818 +ENV NDK_VERSION=26.1.10909125 ENV ANDROID_NDK_HOME=${ANDROID_HOME}/ndk/${NDK_VERSION} RUN aria2c -q -d /tmp -o cmdline-tools.zip \ diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 80bc7a16b2a0c..9d200f4b108c1 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -14,6 +14,15 @@ import sys from pathlib import Path + +def version_to_tuple(version: str) -> tuple: + v = [] + for s in version.split("."): + with contextlib.suppress(ValueError): + v.append(int(s)) + return tuple(v) + + SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) REPO_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) @@ -1084,6 +1093,12 @@ def generate_build_tree( if args.use_cuda: nvcc_threads = number_of_nvcc_threads(args) cmake_args.append("-Donnxruntime_NVCC_THREADS=" + str(nvcc_threads)) + if not disable_float8_types and args.cuda_version: + if version_to_tuple(args.cuda_version) < (11, 8): + raise BuildError( + f"Float 8 types require CUDA>=11.8. They must be disabled on CUDA=={args.cuda_version}. " + f"Add '--disable_types float8' to your command line. See option disable_types." + ) if args.use_rocm: cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) @@ -1171,9 +1186,9 @@ def generate_build_tree( "-Donnxruntime_USE_OPENVINO_AUTO=" + ("ON" if args.use_openvino.startswith("AUTO") else "OFF"), ] - # TensorRT and OpenVINO providers currently only support + # VitisAI and OpenVINO providers currently only support # full_protobuf option. - if args.use_full_protobuf or args.use_tensorrt or args.use_openvino or args.use_vitisai or args.gen_doc: + if args.use_full_protobuf or args.use_openvino or args.use_vitisai or args.gen_doc: cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"] if args.use_tvm and args.llvm_path is not None: diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 3b22d840b9590..0eccd71e47f46 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -201,7 +201,7 @@ stages: buildArch: x64 msbuildPlatform: x64 packageName: x64-cuda - buildparameter: --use_cuda --cuda_version=11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" ${{parameters.AdditionalBuildFlag}} + buildparameter: --use_cuda --cuda_home=$(Agent.TempDirectory)\v11.8 --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" ${{parameters.AdditionalBuildFlag}} runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: true java_artifact_id: onnxruntime_gpu @@ -217,7 +217,7 @@ stages: buildArch: x64 msbuildPlatform: x64 packageName: x64-tensorrt - buildparameter: --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_version=11.8 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" + buildparameter: --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: true java_artifact_id: onnxruntime_gpu @@ -1094,7 +1094,6 @@ stages: DoNugetPack : 'true' DoCompliance: 'false' DoEsrp: ${{ parameters.DoEsrp }} - OrtPackageId: 'Microsoft.ML.OnnxRuntime.DirectML' NuPackScript: | msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreatePackage /p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML /p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} copy $(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) diff --git a/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml index 60f2786bdd856..f5472a49c5148 100644 --- a/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml @@ -34,7 +34,7 @@ jobs: pool: vmImage: 'macOS-13' variables: - MACOSX_DEPLOYMENT_TARGET: '10.14' + MACOSX_DEPLOYMENT_TARGET: '11.0' TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] CCACHE_DIR: '$(Pipeline.Workspace)/ccache' timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index 588b5d049ee3c..b98837078b2d5 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -41,9 +41,14 @@ stages: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} IsReleasePipeline: true - PoolName: 'Azure-Pipelines-EO-Windows2022-aiinfra' + PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' PackageName: 'onnxruntime-web' ExtraBuildArgs: '' + UseWebPoolName: true + RunWebGpuTestsForDebugBuild: false + RunWebGpuTestsForReleaseBuild: true + WebGpuPoolName: 'onnxruntime-Win2022-webgpu-A10' + WebCpuPoolName: 'Azure-Pipelines-EO-Windows2022-aiinfra' - template: templates/react-native-ci.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 2d92108efb46d..4e7093f04a59f 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -15,7 +15,7 @@ parameters: EnvSetupScript: 'setup_env.bat' AgentPool: 'onnxruntime-Win-CPU-2022' AgentDemands: [] - OrtPackageId: Microsoft.ML.OnnxRuntime + OrtPackageId: Microsoft.ML.OnnxRuntime.DirectML BuildConfigurations: ['RelWithDebInfo'] # Options: Debug, RelWithDebInfo RunTests : 'true' EnableLto: true diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index ed84b514fbbcf..8d02a5e5809a2 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -108,6 +108,7 @@ jobs: FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER \ --use_cache \ --use_rocm \ + --use_migraphx \ --rocm_version=$(RocmVersion) \ --rocm_home ${ROCM_HOME} \ --nccl_home ${ROCM_HOME}\ diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml new file mode 100644 index 0000000000000..422fb33eec5de --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml @@ -0,0 +1,22 @@ +trigger: none + +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + +stages: +- template: templates/py-packaging-training-cuda-stage.yml + parameters: + build_py_parameters: --enable_training --update --build + torch_version: '2.1.0' + opset_version: '15' + cuda_version: '12.2' + cmake_cuda_architectures: 70;75;80;86;90 + docker_file: Dockerfile.manylinux2_28_training_cuda12_2 + agent_pool: Onnxruntime-Linux-GPU + upload_wheel: 'yes' + debug_build: false diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 2a94499c7a268..6fdb255606a19 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -4,9 +4,11 @@ stages: parameters: NpmPackagingMode: 'dev' IsReleasePipeline: true - PoolName: 'aiinfra-Win-CPU-2022-web-beta' + PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' BuildStaticLib: true ExtraBuildArgs: '' + UseWebPoolName: true + WebCpuPoolName: 'Azure-Pipelines-EO-Windows2022-aiinfra' # This stage is to test if the combined build works on # o Windows ARM64 diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml new file mode 100644 index 0000000000000..aee42d3675087 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml @@ -0,0 +1,39 @@ +trigger: none + +parameters: + - name: enable_linux_gpu + type: boolean + default: true + - name: enable_windows_gpu + type: boolean + default: true + - name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + - name: cuda_version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 + +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + +stages: + - template: stages/py-cuda-packaging-stage.yml + parameters: + enable_linux_gpu: ${{ parameters.enable_linux_gpu }} + enable_windows_gpu: ${{ parameters.enable_windows_gpu }} + cmake_build_type: ${{ parameters.cmake_build_type }} + cuda_version: ${{ parameters.cuda_version }} \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml new file mode 100644 index 0000000000000..f3d68957d649c --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml @@ -0,0 +1,105 @@ +parameters: +- name: build_py_parameters + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. + type: string + default: '' + +- name: enable_linux_gpu + displayName: 'Whether Linux GPU package is built.' + type: boolean + default: true + +- name: enable_windows_gpu + displayName: 'Whether Windows GPU package is built.' + type: boolean + default: true + +# TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. +- name: cmake_build_type + type: string + displayName: 'Linux packages cmake build type. Linux Only.' + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: cuda_version + type: string + displayName: 'CUDA version. Windows Only.' + default: '12.2' + values: + - 11.8 + - 12.2 + +stages: +- stage: Python_Packaging + dependsOn: [] + variables: + - name: docker_base_image + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 + - name: linux_trt_version + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: 8.6.1.6-1.cuda11.8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: 8.6.1.6-1.cuda12.0 + - name: win_trt_home + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0 + - name: win_cuda_home + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: $(Agent.TempDirectory)\v11.8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: $(Agent.TempDirectory)\v12.2 + jobs: + - ${{ if eq(parameters.enable_windows_gpu, true) }}: + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.8' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.9' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.10' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.11' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + + - ${{ if eq(parameters.enable_linux_gpu, true) }}: + - template: ../templates/py-linux-gpu.yml + parameters: + arch: 'x86_64' + machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + docker_base_image: ${{ variables.docker_base_image }} + trt_version: ${{ variables.linux_trt_version }} + cuda_version: ${{ parameters.cuda_version }} diff --git a/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml b/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml index 82a86e2ec8018..e664cf69dec76 100644 --- a/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml +++ b/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml @@ -38,6 +38,7 @@ steps: - ${{if eq(parameters.WithCache, true)}}: - script: | + set -e -x pushd '$(Build.SourcesDirectory)/cmake/external/emsdk' source ./emsdk_env.sh export PATH=$(Build.SourcesDirectory)/cmake/external/emsdk/:$PATH @@ -66,9 +67,9 @@ steps: EM_DIR: '$(Build.SourcesDirectory)/cmake/external/emsdk/upstream/emscripten' - ${{if eq(parameters.WithCache, false)}}: - - task: PythonScript@0 - displayName: '${{parameters.DisplayName}}' - inputs: - scriptPath: '$(Build.SourcesDirectory)/tools/ci_build/build.py' - arguments: ${{parameters.Arguments}} - workingDirectory: '$(Build.BinariesDirectory)' + - script: | + set -e -x + source $(Build.SourcesDirectory)/cmake/external/emsdk/emsdk_env.sh + cd '$(Build.BinariesDirectory)' + python3 '$(Build.SourcesDirectory)/tools/ci_build/build.py' ${{parameters.Arguments}} + displayName: ${{parameters.DisplayName}} diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml index 663ce4338c99f..5ee425405ac70 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml @@ -39,6 +39,15 @@ steps: mkdir $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib mkdir $(Build.BinariesDirectory)\${{parameters.artifactName}}\include + if exist $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.dll ( + echo "cuda context headers copied" + mkdir $(Build.BinariesDirectory)\${{parameters.artifactName}}\include\core\providers\cuda + copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\resource.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include\core\providers + copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\custom_op_context.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include\core\providers + copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\cuda\cuda_context.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include\core\providers\cuda + copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\cuda\cuda_resource.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include\core\providers\cuda + ) + echo "Directories created" copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_shared.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index a78f743c15347..f2deb2041e06e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.114 + version: 1.0.118 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.114 + version: 1.0.118 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml new file mode 100644 index 0000000000000..ff7f0957e94ba --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -0,0 +1,51 @@ +parameters: + - name: DownloadCUDA + type: boolean + default: false + - name: DownloadTRT + type: boolean + default: false + - name: CudaVersion + type: string + default: '11.8' + values: + - 11.8 + - 12.2 + +steps: + + - ${{ if eq(parameters.DownloadCUDA, true) }}: + - powershell: | + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.CudaVersion }} $(Agent.TempDirectory) + displayName: 'Download CUDA SDK v${{ parameters.CudaVersion }}' + - powershell: | + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}\bin;$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}\extras\CUPTI\lib64" + displayName: 'Append CUDA SDK Directory to PATH' + - task: CmdLine@2 + inputs: + script: | + echo %PATH% + displayName: 'Print PATH' + + - ${{ if eq(parameters.DownloadTRT, true) }}: + - ${{ if eq(parameters.CudaVersion, '11.8') }}: + - powershell: | + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8 $(Agent.TempDirectory) + displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8' + - powershell: | + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib" + displayName: 'Append TensorRT Directory to PATH' + + - ${{ if eq(parameters.CudaVersion, '12.2') }}: + - powershell: | + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0 $(Agent.TempDirectory) + displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0' + - powershell: | + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0\lib" + displayName: 'Append TensorRT Directory to PATH' + + - task: CmdLine@2 + inputs: + script: | + echo %PATH% + displayName: 'Print PATH' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 404699f705344..e40c4d0e95dc5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -4,6 +4,7 @@ parameters: - name: EnvSetupScript type: string + default: setup_env.bat - name: job_name_suffix type: string @@ -175,6 +176,11 @@ jobs: msbuildArguments: '-t:restore -p:OrtPackageId=$(OrtPackageId)' workingDirectory: '$(Build.SourcesDirectory)\csharp' + - script: | + python3 tools\ValidateNativeDelegateAttributes.py + displayName: 'Validate C# native delegates' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + - task: MSBuild@1 displayName: 'Build C#' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index f085686f4bc28..900d8e11d1006 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -46,7 +46,7 @@ jobs: variables: EnvSetupScript: setup_env.bat buildArch: x64 - CommonBuildArgs: '--parallel --config ${{ parameters.BuildConfig }} --skip_submodule_sync --build_wasm --use_xnnpack ${{ parameters.ExtraBuildArgs }}' + CommonBuildArgs: '--parallel --config ${{ parameters.BuildConfig }} --skip_submodule_sync --build_wasm ${{ parameters.ExtraBuildArgs }}' runCodesignValidationInjection: false TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache @@ -81,9 +81,6 @@ jobs: versionSpec: '3.8' addToPath: true architecture: $(buildArch) - - task: NodeTool@0 - inputs: - versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 @@ -93,13 +90,20 @@ jobs: arguments: --new_dir $(Build.BinariesDirectory)/deps workingDirectory: $(Build.BinariesDirectory) - - script: | - set -ex - cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 3.1.47 ccache-git-emscripten-64bit - ./emsdk activate 3.1.47 ccache-git-emscripten-64bit - displayName: 'emsdk install and activate ccache for emscripten' - condition: eq('${{ parameters.WithCache }}', 'true') + - ${{if eq(parameters.WithCache, true)}}: + - script: | + set -ex + cd '$(Build.SourcesDirectory)/cmake/external/emsdk' + ./emsdk install 3.1.47 ccache-git-emscripten-64bit + ./emsdk activate 3.1.47 ccache-git-emscripten-64bit + displayName: 'emsdk install and activate ccache for emscripten' + - ${{if eq(parameters.WithCache, false)}}: + - script: | + set -ex + cd '$(Build.SourcesDirectory)/cmake/external/emsdk' + ./emsdk install 3.1.47 + ./emsdk activate 3.1.47 + displayName: 'emsdk install and activate ccache for emscripten' - template: build-linux-wasm-step.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml index abd8c94dabd91..e788e4b3dddaa 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-web-init-and-check.yml @@ -3,10 +3,6 @@ steps: npm ci workingDirectory: '$(Build.SourcesDirectory)/js' displayName: 'npm ci /js/' -- script: | - npm run lint - workingDirectory: '$(Build.SourcesDirectory)/js' - displayName: 'run ESLint without TS type populated' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)/js/common' @@ -19,6 +15,10 @@ steps: npm ci workingDirectory: '$(Build.SourcesDirectory)/js/web' displayName: 'npm ci /js/web/' +- script: | + npm run prebuild + workingDirectory: '$(Build.SourcesDirectory)/js/web' + displayName: 'run TypeScript type check in /js/web/' - script: | npm run lint workingDirectory: '$(Build.SourcesDirectory)/js' diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index f5e5435cfaca0..fd2113502478a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -31,7 +31,7 @@ jobs: workspace: clean: all variables: - MACOSX_DEPLOYMENT_TARGET: '10.14' + MACOSX_DEPLOYMENT_TARGET: '11.0' ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] PROTO_CACHE_DIR: $(Pipeline.Workspace)/ccache_proto diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml index f68847afff379..8cc48aac7a3b9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml @@ -17,7 +17,24 @@ parameters: - Release - RelWithDebInfo - MinSizeRel - +- name: docker_base_image + type: string + default: 'nvidia/cuda:11.8.0-cudnn8-devel-ubi8' + values: + - nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + - nvidia/cuda:12.2.2-cudnn8-devel-ubi8 +- name: trt_version + type: string + default: '8.6.1.6-1.cuda11.8' + values: + - 8.6.1.6-1.cuda11.8 + - 8.6.1.6-1.cuda12.0 +- name: cuda_version + type: string + default: '11.8' + values: + - 11.8 + - 12.2 jobs: - job: Linux_py_GPU_Wheels_${{ parameters.arch }} timeoutInMinutes: 240 @@ -26,7 +43,13 @@ jobs: pool: ${{ parameters.machine_pool }} variables: # The build machine pool doesn't have dotnet, so it can't run CG. - skipComponentGovernanceDetection: true + - name: skipComponentGovernanceDetection + value: true + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: -x ${{ parameters.extra_build_arg }} + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' steps: - checkout: self clean: true @@ -40,12 +63,12 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 - --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg BASEIMAGE=${{ parameters.docker_base_image }} + --build-arg TRT_VERSION=${{ parameters.trt_version }} --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }} " - Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} + Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} - task: Bash@3 @@ -53,8 +76,7 @@ jobs: inputs: targetType: filePath filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - # please check ONNXRUNTIME_CUDA_VERSION in tools/ci_build/github/linux/build_linux_arm64_python_package.sh - arguments: -i onnxruntimecuda118xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} -x "${{ parameters.extra_build_arg }}" + arguments: -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} $(extra_build_args) - task: PublishBuildArtifacts@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml index 0774c3350b9b1..db3782c69cf62 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml @@ -46,9 +46,17 @@ jobs: pool: ${{ parameters.machine_pool }} variables: # The build machine pool doesn't have dotnet, so it can't run CG. - skipComponentGovernanceDetection: true - ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: skipComponentGovernanceDetection + value: true + - name: ORT_CACHE_DIR + value: $(Agent.TempDirectory)/ort_ccache + - name: TODAY + value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: -x ${{ parameters.extra_build_arg }} + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -82,7 +90,7 @@ jobs: inputs: targetType: filePath filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} -x "${{ parameters.extra_build_arg }}" + arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) ${{ if eq(parameters.with_cache, 'true') }}: env: ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 1a67ace5e85fa..f2b91bbaacb89 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -335,7 +335,7 @@ stages: pool: vmImage: 'macOS-13' variables: - MACOSX_DEPLOYMENT_TARGET: '10.15' + MACOSX_DEPLOYMENT_TARGET: '11.0' strategy: matrix: Python38: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index 919749cac15b6..501251eaff20f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -14,21 +14,32 @@ parameters: - name: ENV_SETUP_SCRIPT type: string + default: '' - name: BUILD_PY_PARAMETERS displayName: > Extra parameters to pass to build.py. Don't put newlines in here. type: string default: '' - +- name: CudaVersion + type: string + default: '11.8' + values: + - 11.8 + - 12.2 jobs: - job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 240 workspace: clean: all - pool: ${{ parameters.MACHINE_POOL }} + pool: + name: ${{ parameters.MACHINE_POOL }} +# demands: +# - ImageVersionOverride -equals 1.0.367516 variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' + CUDA_MODULE_LOADING: 'LAZY' steps: - checkout: self clean: true @@ -61,10 +72,21 @@ jobs: - template: download-deps.yml - - template: jobs/set-winenv.yml - parameters: - EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} - DownloadCUDA: true + - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}: + - template: jobs/set-winenv.yml + parameters: + EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} + ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: + DownloadCUDA: true + + - ${{ if eq(parameters.ENV_SETUP_SCRIPT, '') }}: + - template: jobs/download_win_gpu_library.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: + DownloadCUDA: true + ${{ if contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt') }}: + DownloadTRT: true - task: PythonScript@0 displayName: 'Update deps.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/rocm.yml b/tools/ci_build/github/azure-pipelines/templates/rocm.yml index f41dc1c49f3d4..2e9e6c6b35a2e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/rocm.yml +++ b/tools/ci_build/github/azure-pipelines/templates/rocm.yml @@ -79,6 +79,7 @@ jobs: /onnxruntime_src/tools/ci_build/build.py \ --config ${{ parameters.BuildConfig }} \ --use_rocm \ + --use_migraphx \ --rocm_version=${{ parameters.RocmVersion }} \ --rocm_home=/opt/rocm \ --nccl_home=/opt/rocm \ diff --git a/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml b/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml index 8cc7f63a193cc..b8dba89b0b899 100644 --- a/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml @@ -3,7 +3,7 @@ parameters: - name: AndroidNdkVersion type: string - default: "26.0.10792818" # LTS version + default: "26.1.10909125" # LTS version steps: - bash: | diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index 3f6c6af753a98..9982b36509b68 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -26,7 +26,7 @@ parameters: - name: WASMTemplate type: string - default: win-wasm-ci.yml + default: linux-wasm-ci.yml # parameter couldn't be compared by string, so add one boolean parameter. - name: UseWebPoolName type: boolean @@ -39,10 +39,10 @@ parameters: default: false - name: WebGpuPoolName type: string - default: '' + default: 'onnxruntime-Win2022-webgpu-A10' - name: WebCpuPoolName type: string - default: '' + default: 'onnxruntime-Win-CPU-2022-web' - name: ExtraBuildArgs displayName: 'Extra build command line arguments' @@ -65,7 +65,6 @@ stages: clean: all steps: - checkout: self - fetchDepth: 1 submodules: false - script: | git submodule sync -- cmake/external/onnx diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index a5925d16564fe..79647cc5699c8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -46,7 +46,7 @@ jobs: variables: EnvSetupScript: setup_env.bat buildArch: x64 - CommonBuildArgs: '--parallel --config ${{ parameters.BuildConfig }} --skip_submodule_sync --cmake_generator "MinGW Makefiles" --build_wasm --use_xnnpack ${{ parameters.ExtraBuildArgs }}' + CommonBuildArgs: '--parallel --config ${{ parameters.BuildConfig }} --skip_submodule_sync --cmake_generator "MinGW Makefiles" --build_wasm ${{ parameters.ExtraBuildArgs }}' runCodesignValidationInjection: false timeoutInMinutes: ${{ parameters.TimeoutInMinutes }} workspace: diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 8c926619c797d..b7ec3305003d7 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -95,14 +95,22 @@ jobs: targetFolder: $(Build.SourcesDirectory)\js\web\lib\wasm\binding flattenFolders: true displayName: 'Binplace js files' + - script: | + npm i -g puppeteer + workingDirectory: '$(Build.SourcesDirectory)' + displayName: 'Use puppeteer to prepare Chrome for tests' + - script: | + FOR /F "tokens=* USEBACKQ" %%F IN (`where /r %HOMEDRIVE%%HOMEPATH%\.cache\puppeteer chrome.exe`) DO ( + SET var=%%F + ECHO found chrome.exe: %%F + ) + ECHO ##vso[task.setvariable variable=CHROME_BIN;]%var% + workingDirectory: '$(Build.SourcesDirectory)' + displayName: 'Set CHROME_BIN' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js' displayName: 'npm ci /js/' - - script: | - npm run lint - workingDirectory: '$(Build.SourcesDirectory)\js' - displayName: 'run ESLint without TS type populated' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js\common' @@ -115,6 +123,10 @@ jobs: npm ci workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'npm ci /js/web/' + - script: | + npm run prebuild + workingDirectory: '$(Build.SourcesDirectory)\js\web' + displayName: 'run TypeScript type check in /js/web/' - script: | npm run lint workingDirectory: '$(Build.SourcesDirectory)\js' @@ -157,44 +169,36 @@ jobs: errorActionPreference: stop displayName: 'Pack NPM packages' - script: | - npm test -- -e=edge -b=webgl,wasm,xnnpack + npm test -- -e=chrome -b=webgl,wasm,xnnpack workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (wasm,webgl,xnnpack backend)' condition: eq('${{ parameters.RunWebGpuTests }}', 'false') - retryCountOnTaskFailure: 3 - script: | - npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags) + npm test -- -e=chrome -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - retryCountOnTaskFailure: 3 - script: | - npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) + npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - # temporarily allow this test to fail, so that people are not blocked. - # investigation is ongoing for the root cause of the random failure (Edge crash). - # TODO: remove this line once the root cause is found and fixed. - continueOnError: true - script: | - npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) + npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - retryCountOnTaskFailure: 3 - script: | - npm test -- --webgl-texture-pack-mode -b=webgl -e=edge + npm test -- --webgl-texture-pack-mode -b=webgl -e=chrome workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests - WebGL: packed mode' - retryCountOnTaskFailure: 3 - script: | - npm test -- --wasm-enable-proxy -b=wasm -e=edge + npm test -- --wasm-enable-proxy -b=wasm -e=chrome workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests - WebAssembly: proxy' condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) - script: | - npm run test:e2e -- --browser=Edge_default + npm run test:e2e -- --browser=Chrome_default workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'E2E package consuming test' condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index ed010b5619db5..d7ffc1828c943 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -40,7 +40,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'Debug' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --build_java --build_nodejs --build_wheel --disable_memleak_checker msbuildPlatform: x64 @@ -59,7 +58,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 # Compare to our Nuget packaging pipeline, this job has "--build_wheel" but doesn't have "--enable_lto --disable_rtti --use_telemetry --enable_wcos" # Python bindings use typeid so I can't disable RTTI here. If it causes a problem, we will need to split this job to two jobs. @@ -80,7 +78,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --build_wheel --use_dnnl --build_java msbuildPlatform: x64 @@ -101,7 +98,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --build_wheel --use_xnnpack msbuildPlatform: x64 @@ -120,7 +116,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --use_winml --enable_wcos --disable_rtti --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.22000.0 msbuildPlatform: x64 @@ -160,7 +155,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'Debug' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --enable_training --build_wheel --disable_memleak_checker msbuildPlatform: x64 @@ -179,7 +173,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --enable_training --build_wheel msbuildPlatform: x64 @@ -198,7 +191,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --enable_training_apis msbuildPlatform: x64 @@ -215,10 +207,17 @@ stages: - stage: x64_release_azure dependsOn: [] jobs: + - job: + steps: + - powershell: | + Write-Host "##vso[task.prependpath]$(Build.BinariesDirectory)\RelWithDebInfo\_deps\vcpkg-src\installed\x86-windows\bin" + $env:PATH + Write-Host "##vso[task.prependpath]$(Build.BinariesDirectory)\RelWithDebInfo\_deps\vcpkg-src\installed\x64-windows\bin" + $env:PATH + displayName: 'Append x64-windows and x86-windows to PATH' - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_azure.bat buildArch: x64 additionalBuildFlags: --use_azure --use_lock_free_queue msbuildPlatform: x64 @@ -231,3 +230,5 @@ stages: GenerateDocumentation: false WITH_CACHE: true MachinePool: 'onnxruntime-Win-CPU-2022' + + diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml index 15a786516396c..658c358aa4523 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml @@ -56,7 +56,7 @@ jobs: WithCache: True Today: $(TODAY) AdditionalKey: "gpu-tensorrt | $(BuildConfig)" - BuildPyArguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_version=11.8 --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75' + BuildPyArguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75' MsbuildArguments: $(MsbuildArguments) BuildArch: $(buildArch) Platform: 'x64' @@ -76,7 +76,7 @@ jobs: del wheel_filename_file python.exe -m pip install -q --upgrade %WHEEL_FILENAME% set PATH=$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig);%PATH% - python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_version=11.8 --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 + python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' displayName: 'Run tests' diff --git a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh similarity index 78% rename from tools/ci_build/github/linux/build_linux_arm64_python_package.sh rename to tools/ci_build/github/linux/build_linux_python_package.sh index 516f320cd64c4..3c1c65c9a6862 100755 --- a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_python_package.sh @@ -15,9 +15,11 @@ do case "${parameter_Option}" in #GPU or CPU. d) BUILD_DEVICE=${OPTARG};; -p) PYTHON_EXES=(${OPTARG});; -x) EXTRA_ARG=(${OPTARG});; +p) PYTHON_EXES=${OPTARG};; +x) EXTRA_ARG=${OPTARG};; c) BUILD_CONFIG=${OPTARG};; +*) echo "Usage: $0 -d [-p ] [-x ] [-c ]" + exit 1;; esac done @@ -48,7 +50,7 @@ if [ "$ARCH" == "x86_64" ] && [ "$GCC_VERSION" -ge 9 ]; then fi echo "EXTRA_ARG:" -echo $EXTRA_ARG +echo "$EXTRA_ARG" if [ "$EXTRA_ARG" != "" ]; then BUILD_ARGS+=("$EXTRA_ARG") @@ -60,19 +62,19 @@ if [ "$ARCH" == "x86_64" ]; then fi if [ "$BUILD_DEVICE" == "GPU" ]; then + SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/') #Enable CUDA and TRT EPs. - ONNXRUNTIME_CUDA_VERSION="11.8" - BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$ONNXRUNTIME_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") + BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") fi export CFLAGS export CXXFLAGS for PYTHON_EXE in "${PYTHON_EXES[@]}" do - rm -rf /build/$BUILD_CONFIG + rm -rf /build/"$BUILD_CONFIG" ${PYTHON_EXE} /onnxruntime_src/tools/ci_build/build.py "${BUILD_ARGS[@]}" - cp /build/$BUILD_CONFIG/dist/*.whl /build/dist + cp /build/"$BUILD_CONFIG"/dist/*.whl /build/dist done which ccache && ccache -sv && ccache -z diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh index 73444b35a6768..42973a8fcb5b8 100755 --- a/tools/ci_build/github/linux/copy_strip_binary.sh +++ b/tools/ci_build/github/linux/copy_strip_binary.sh @@ -56,6 +56,15 @@ cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h $BINARY_DIR/$ARTIFACT_NAME/include +if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_cuda.so" ]]; then +# copy headers for context context used in custom ops +mkdir -p $BINARY_DIR/$ARTIFACT_NAME/include/core/providers/cuda +cp $SOURCE_DIR/include/onnxruntime/core/providers/custom_op_context.h $BINARY_DIR/$ARTIFACT_NAME/include/core/providers/custom_op_context.h +cp $SOURCE_DIR/include/onnxruntime/core/providers/resource.h $BINARY_DIR/$ARTIFACT_NAME/include/core/providers/resource.h +cp $SOURCE_DIR/include/onnxruntime/core/providers/cuda/cuda_context.h $BINARY_DIR/$ARTIFACT_NAME/include/core/providers/cuda/cuda_context.h +cp $SOURCE_DIR/include/onnxruntime/core/providers/cuda/cuda_resource.h $BINARY_DIR/$ARTIFACT_NAME/include/core/providers/cuda/cuda_resource.h +fi + # copy the README, licence and TPN cp $SOURCE_DIR/README.md $BINARY_DIR/$ARTIFACT_NAME/README.md cp $SOURCE_DIR/docs/Privacy.md $BINARY_DIR/$ARTIFACT_NAME/Privacy.md diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index 19599c9f613d4..9e12fe8c75451 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -31,7 +31,7 @@ RUN yum install -y hipify-clang RUN yum -y install wget # rocm lib -RUN yum install -y miopen-hip-devel rocblas-devel rocrand-devel rccl-devel hipsparse-devel hipfft-devel hipcub-devel hipblas-devel rocthrust-devel +RUN yum install -y miopen-hip-devel rocblas-devel rocrand-devel rccl-devel hipsparse-devel hipfft-devel hipcub-devel hipblas-devel rocthrust-devel migraphx-devel ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 new file mode 100644 index 0000000000000..a36f60b87768d --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 @@ -0,0 +1,180 @@ +ARG BASEIMAGE=nvidia/cuda:12.2.2-cudnn8-devel-ubi8 +ARG POLICY=manylinux2014 +ARG PLATFORM=x86_64 +ARG DEVTOOLSET_ROOTPATH= +ARG LD_LIBRARY_PATH_ARG= +ARG PREPEND_PATH= + +#We need both CUDA and manylinux. But the CUDA Toolkit End User License Agreement says NVIDIA CUDA Driver Libraries(libcuda.so, libnvidia-ptxjitcompiler.so) are only distributable in applications that meet this criteria: +#1. The application was developed starting from a NVIDIA CUDA container obtained from Docker Hub or the NVIDIA GPU Cloud, and +#2. The resulting application is packaged as a Docker container and distributed to users on Docker Hub or the NVIDIA GPU Cloud only. +#So we use CUDA as the base image then add manylinux on top of it. + +#Build manylinux2014 docker image begin +FROM $BASEIMAGE AS runtime_base +ARG POLICY +ARG PLATFORM +ARG DEVTOOLSET_ROOTPATH +ARG LD_LIBRARY_PATH_ARG +ARG PREPEND_PATH +LABEL maintainer="The ManyLinux project" + +ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} +ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 +ENV DEVTOOLSET_ROOTPATH=${DEVTOOLSET_ROOTPATH} +ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG} +ENV PATH=${PREPEND_PATH}${PATH} +ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig + +# first copy the fixup mirrors script, keep the script around +COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors + +# setup entrypoint, this will wrap commands with `linux32` with i686 images +COPY build_scripts/install-entrypoint.sh \ + build_scripts/build_utils.sh \ + /build_scripts/ + +RUN /build_scripts/install-entrypoint.sh && rm -rf /build_scripts +COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint +ENTRYPOINT ["manylinux-entrypoint"] + +COPY build_scripts/install-runtime-packages.sh \ + build_scripts/build_utils.sh \ + /build_scripts/ +RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ + +COPY build_scripts/build_utils.sh /build_scripts/ + +COPY build_scripts/install-autoconf.sh /build_scripts/ +RUN export AUTOCONF_ROOT=autoconf-2.71 && \ + export AUTOCONF_HASH=431075ad0bf529ef13cb41e9042c542381103e80015686222b8a9d4abef42a1c && \ + export AUTOCONF_DOWNLOAD_URL=http://ftp.gnu.org/gnu/autoconf && \ + manylinux-entrypoint /build_scripts/install-autoconf.sh + +COPY build_scripts/install-automake.sh /build_scripts/ +RUN export AUTOMAKE_ROOT=automake-1.16.5 && \ + export AUTOMAKE_HASH=07bd24ad08a64bc17250ce09ec56e921d6343903943e99ccf63bbf0705e34605 && \ + export AUTOMAKE_DOWNLOAD_URL=http://ftp.gnu.org/gnu/automake && \ + manylinux-entrypoint /build_scripts/install-automake.sh + +COPY build_scripts/install-libtool.sh /build_scripts/ +RUN export LIBTOOL_ROOT=libtool-2.4.7 && \ + export LIBTOOL_HASH=04e96c2404ea70c590c546eba4202a4e12722c640016c12b9b2f1ce3d481e9a8 && \ + export LIBTOOL_DOWNLOAD_URL=http://ftp.gnu.org/gnu/libtool && \ + manylinux-entrypoint /build_scripts/install-libtool.sh + +COPY build_scripts/install-libxcrypt.sh /build_scripts/ +RUN export LIBXCRYPT_VERSION=4.4.28 && \ + export LIBXCRYPT_HASH=db7e37901969cb1d1e8020cb73a991ef81e48e31ea5b76a101862c806426b457 && \ + export LIBXCRYPT_DOWNLOAD_URL=https://github.com/besser82/libxcrypt/archive && \ + export PERL_ROOT=perl-5.34.0 && \ + export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ + export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ + manylinux-entrypoint /build_scripts/install-libxcrypt.sh + +FROM runtime_base AS build_base +COPY build_scripts/install-build-packages.sh /build_scripts/ +RUN manylinux-entrypoint /build_scripts/install-build-packages.sh + + +FROM build_base AS build_git +COPY build_scripts/build-git.sh /build_scripts/ +RUN export GIT_ROOT=git-2.36.2 && \ + export GIT_HASH=6dc2cdea5fb23d823ba4871cc23222c1db31dfbb6d6c6ff74c4128700df57c68 && \ + export GIT_DOWNLOAD_URL=https://www.kernel.org/pub/software/scm/git && \ + manylinux-entrypoint /build_scripts/build-git.sh + + +FROM build_base AS build_cpython +COPY build_scripts/build-sqlite3.sh /build_scripts/ +RUN export SQLITE_AUTOCONF_ROOT=sqlite-autoconf-3390200 && \ + export SQLITE_AUTOCONF_HASH=852be8a6183a17ba47cee0bbff7400b7aa5affd283bf3beefc34fcd088a239de && \ + export SQLITE_AUTOCONF_DOWNLOAD_URL=https://www.sqlite.org/2022 && \ + manylinux-entrypoint /build_scripts/build-sqlite3.sh + +COPY build_scripts/build-openssl.sh /build_scripts/ +RUN export OPENSSL_ROOT=openssl-1.1.1q && \ + export OPENSSL_HASH=d7939ce614029cdff0b6c20f0e2e5703158a489a72b2507b8bd51bf8c8fd10ca && \ + export OPENSSL_DOWNLOAD_URL=https://www.openssl.org/source && \ + manylinux-entrypoint /build_scripts/build-openssl.sh + +COPY build_scripts/build-cpython.sh /build_scripts/ + + +FROM build_cpython AS build_cpython38 +COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt +RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.8.13 + + +FROM build_cpython AS build_cpython39 +COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt +RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.9.13 + + +FROM build_cpython AS build_cpython310 +COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt +RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 + +FROM build_cpython AS build_cpython311 +COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt +RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.2 + +FROM build_cpython AS all_python +COPY build_scripts/install-pypy.sh \ + build_scripts/pypy.sha256 \ + build_scripts/finalize-python.sh \ + /build_scripts/ +RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.8 7.3.9 +RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.9 7.3.9 +COPY --from=build_cpython38 /opt/_internal /opt/_internal/ +COPY --from=build_cpython39 /opt/_internal /opt/_internal/ +COPY --from=build_cpython310 /opt/_internal /opt/_internal/ +COPY --from=build_cpython311 /opt/_internal /opt/_internal/ +RUN manylinux-entrypoint /build_scripts/finalize-python.sh + + +FROM runtime_base +COPY --from=build_git /manylinux-rootfs / +COPY --from=build_cpython /manylinux-rootfs / +COPY --from=all_python /opt/_internal /opt/_internal/ +COPY build_scripts/finalize.sh \ + build_scripts/python-tag-abi-tag.py \ + build_scripts/requirements3.8.txt \ + build_scripts/requirements3.9.txt \ + build_scripts/requirements3.10.txt \ + build_scripts/requirements3.11.txt \ + build_scripts/requirements-base-tools.txt \ + /build_scripts/ +COPY build_scripts/requirements-tools/* /build_scripts/requirements-tools/ +RUN manylinux-entrypoint /build_scripts/finalize.sh && rm -rf /build_scripts + +ENV SSL_CERT_FILE=/opt/_internal/certs.pem + +CMD ["/bin/bash"] + +#Build manylinux2014 docker image end +ARG PYTHON_VERSION=3.9 +ARG TORCH_VERSION=2.1.0 +ARG OPSET_VERSION=15 +ARG INSTALL_DEPS_EXTRA_ARGS + +#Add our own dependencies +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && \ + /tmp/scripts/manylinux/install_centos.sh && \ + /tmp/scripts/install_os_deps.sh -d gpu $INSTALL_DEPS_EXTRA_ARGS && \ + /tmp/scripts/install_rust.sh + +ENV PATH="/root/.cargo/bin/:$PATH" + +RUN /tmp/scripts/install_ninja.sh && \ + /tmp/scripts/install_python_deps.sh -d gpu -v 12.2 -p $PYTHON_VERSION -h $TORCH_VERSION $INSTALL_DEPS_EXTRA_ARGS && \ + rm -rf /tmp/scripts + +ARG BUILD_UID=1001 +ARG BUILD_USER=onnxruntimedev +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER +ENV PATH /usr/local/dotnet:$PATH +ENV ORTMODULE_ONNX_OPSET_VERSION=$OPSET_VERSION diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh index 3bca6413100a2..da8a45e00cc90 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh @@ -19,7 +19,9 @@ fi export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" +# This may install PyTorch, which will be overrided by the PyTorch local build below. /opt/python/cp39-cp39/bin/python3.9 -m pip install transformers + # beartype is installed here so that onnxscript installation step won't # install a version PyTorch doesn't like. Once beartype fixes this problem. # We can remove this line. diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt new file mode 100644 index 0000000000000..152a17db90366 --- /dev/null +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.1.0_cu12.2/requirements.txt @@ -0,0 +1,7 @@ +--pre +-f https://download.pytorch.org/whl/torch_stable.html +torch==2.1.0+cu121 +torchvision==0.16.0+cu121 +torchtext==0.16.0 +packaging==23.1 +setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/run_python_dockerbuild.sh b/tools/ci_build/github/linux/run_python_dockerbuild.sh index 18ac6482827f9..ff2ce6f7ff231 100755 --- a/tools/ci_build/github/linux/run_python_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_python_dockerbuild.sh @@ -9,24 +9,32 @@ i) DOCKER_IMAGE=${OPTARG};; d) DEVICE=${OPTARG};; x) BUILD_EXTR_PAR=${OPTARG};; c) BUILD_CONFIG=${OPTARG};; +*) echo "Usage: $0 -i -d [-x ] [-c ]" + exit 1;; esac done -mkdir -p $HOME/.onnx +mkdir -p "${HOME}/.onnx" +DOCKER_SCRIPT_OPTIONS="-d ${DEVICE} -c ${BUILD_CONFIG}" + +if [ "${BUILD_EXTR_PAR}" != "" ] ; then + DOCKER_SCRIPT_OPTIONS+=" -x ${BUILD_EXTR_PAR}" +fi + docker run --rm \ --volume /data/onnx:/data/onnx:ro \ - --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src \ - --volume $BUILD_BINARIESDIRECTORY:/build \ + --volume "${BUILD_SOURCESDIRECTORY}:/onnxruntime_src" \ + --volume "${BUILD_BINARIESDIRECTORY}:/build" \ --volume /data/models:/build/models:ro \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume "${HOME}/.onnx:/home/onnxruntimedev/.onnx" \ -w /onnxruntime_src \ -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ $ADDITIONAL_DOCKER_PARAMETER \ - $DOCKER_IMAGE tools/ci_build/github/linux/build_linux_arm64_python_package.sh -d $DEVICE -c $BUILD_CONFIG -x $BUILD_EXTR_PAR + $DOCKER_IMAGE tools/ci_build/github/linux/build_linux_python_package.sh $DOCKER_SCRIPT_OPTIONS -sudo rm -rf $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/onnxruntime $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/pybind11 \ - $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/models $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/_deps \ - $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/CMakeFiles -cd $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG -find -executable -type f > $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/perms.txt +sudo rm -rf "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/onnxruntime" "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/pybind11" \ + "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/models" "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/_deps" \ + "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/CMakeFiles" +cd "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}" +find -executable -type f > "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/perms.txt" diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 412bc00d02778..2ec826fc8fd8c 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -80,7 +80,6 @@ RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bi RUN pip install torch==2.0.1 torchvision==0.15.2 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-${ROCM_VERSION}/ && \ pip install torch-ort --no-dependencies - ##### Install Cupy to decrease CPU utilization # Install non dev openmpi RUN wget https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.bz2 && \ @@ -130,6 +129,9 @@ RUN pip install \ pytest-xdist \ pytest-rerunfailures +# Install migraphx +RUN apt update && apt install -y migraphx + ENV ORTMODULE_ONNX_OPSET_VERSION=15 ARG BUILD_UID=1001 diff --git a/tools/ci_build/github/windows/setup_env_azure.bat b/tools/ci_build/github/windows/setup_env_azure.bat deleted file mode 100644 index 44ba34b0bf23a..0000000000000 --- a/tools/ci_build/github/windows/setup_env_azure.bat +++ /dev/null @@ -1,4 +0,0 @@ -REM Copyright (c) Microsoft Corporation. All rights reserved. -REM Licensed under the MIT License. -set PATH=%cd%\RelWithDebInfo\_deps\vcpkg-src\installed\x64-windows\bin;%cd%\RelWithDebInfo\_deps\vcpkg-src\installed\x86-windows\bin;%PATH% -set GRADLE_OPTS=-Dorg.gradle.daemon=false diff --git a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py index 113b5398f3981..9eccb7c36455f 100644 --- a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py +++ b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py @@ -10,9 +10,8 @@ import sys import onnx -from onnx import shape_inference -from ..onnx_model_utils import get_opsets_imported +from ..onnx_model_utils import ModelProtoWithShapeInfo, get_opsets_imported from ..reduced_build_config_parser import parse_config cpp_to_tensorproto_type = { @@ -265,15 +264,13 @@ def run_check(model_path: pathlib.Path, mobile_pkg_build_config: pathlib.Path, l ) model_file = model_path.resolve(strict=True) - model = onnx.load(str(model_file)) # we need to run shape inferencing to populate that type info for node outputs. # we will get warnings if the model uses ORT contrib ops (ONNX does not have shape inferencing for those), # and shape inferencing will be lost downstream of those. # TODO: add support for checking ORT format model as it will have full type/shape info for all nodes - model_with_type_info = shape_inference.infer_shapes(model) - - return run_check_with_model(model_with_type_info, mobile_pkg_build_config, logger) + model_wrapper = ModelProtoWithShapeInfo(model_file) + return run_check_with_model(model_wrapper.model_with_shape_info, mobile_pkg_build_config, logger) def main(): diff --git a/tools/python/util/mobile_helpers/usability_checker.py b/tools/python/util/mobile_helpers/usability_checker.py index f8b0bfe707ead..dcb3451a5e0fa 100644 --- a/tools/python/util/mobile_helpers/usability_checker.py +++ b/tools/python/util/mobile_helpers/usability_checker.py @@ -13,6 +13,7 @@ import onnx from ..onnx_model_utils import ( + ModelProtoWithShapeInfo, get_producer_consumer_maps, is_fixed_size_tensor, iterate_graph_per_graph_func, @@ -464,9 +465,9 @@ def check_shapes(graph: onnx.GraphProto, logger: Optional[logging.Logger] = None return dynamic_inputs, num_dynamic_values -def checker(model_path, logger: logging.Logger): - model = onnx.load(model_path) - model_with_shape_info = onnx.shape_inference.infer_shapes(model) +def checker(model_path: pathlib.Path, logger: logging.Logger): + model_with_shape_info_wrapper = ModelProtoWithShapeInfo(model_path) + model_with_shape_info = model_with_shape_info_wrapper.model_with_shape_info # create lookup map for efficiency value_to_shape = {} @@ -541,10 +542,10 @@ def analyze_model(model_path: pathlib.Path, skip_optimize: bool = False, logger: with tempfile.TemporaryDirectory() as tmp: if not skip_optimize: tmp_path = pathlib.Path(tmp) / model_path.name - optimize_model(model_path, tmp_path) + optimize_model(model_path, tmp_path, use_external_initializers=True) model_path = tmp_path - try_eps = checker(str(model_path.resolve(strict=True)), logger) + try_eps = checker(model_path.resolve(strict=True), logger) return try_eps diff --git a/tools/python/util/onnx_model_utils.py b/tools/python/util/onnx_model_utils.py index e662d1623f8bd..5c970430a3a82 100644 --- a/tools/python/util/onnx_model_utils.py +++ b/tools/python/util/onnx_model_utils.py @@ -95,6 +95,7 @@ def optimize_model( output_path: pathlib.Path, level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC, log_level: int = 3, + use_external_initializers: bool = False, ): """ Optimize an ONNX model using ONNX Runtime to the specified level @@ -103,12 +104,25 @@ def optimize_model( :param level: onnxruntime.GraphOptimizationLevel to use. Default is ORT_ENABLE_BASIC. :param log_level: Log level. Defaults to Error (3) so we don't get output about unused initializers being removed. Warning (2) or Info (1) may be desirable in some scenarios. + :param use_external_initializers: Set flag to write initializers to an external file. Required if model > 2GB. + Requires onnxruntime 1.17+ """ so = ort.SessionOptions() so.optimized_model_filepath = str(output_path.resolve()) so.graph_optimization_level = level so.log_severity_level = log_level + # save using external initializers so models > 2 GB are handled + if use_external_initializers: + major, minor, rest = ort.__version__.split(".", 3) + if (int(major), int(minor)) >= (1, 17): + so.add_session_config_entry("session.optimized_model_external_initializers_file_name", "external_data.pb") + else: + raise ValueError( + "ONNX Runtime 1.17 or higher required to save initializers as external data when optimizing model. " + f"Current ONNX Runtime version is {ort.__version__}" + ) + # create session to optimize. this will write the updated model to output_path _ = ort.InferenceSession(str(model_path.resolve(strict=True)), so, providers=["CPUExecutionProvider"]) @@ -366,3 +380,34 @@ def get_optimization_level(level): return ort.GraphOptimizationLevel.ORT_ENABLE_ALL raise ValueError("Invalid optimization level of " + level) + + +class ModelProtoWithShapeInfo: + """ + Class to load an ONNX model and run shape inferencing on it to populate the ValueInfo. + The model_with_shape_info property will contain the updated model. + If the model is > 2GB and uses external data a temporary file is required to run shape inferencing successfully. + This helper class handles automatic removal of the temporary file. + """ + + def __init__(self, model_path: pathlib.Path): + """ + :param model_path: Path to ONNX model to load and run shape inferencing on. + """ + + self.model_path = model_path + + model = onnx.load(str(model_path)) + self.model_with_shape_info = onnx.shape_inference.infer_shapes(model, strict_mode=True) + + # ONNX has a silent failure from the call to infer_shapes when the model is > 2GB. + # We detect that by checking the nodes in the returned model. + self._tmp_model_path = None + if len(model.graph.node) > 0 and len(self.model_with_shape_info.graph.node) == 0: + self._tmp_model_path = pathlib.Path(model_path).with_suffix(".temp_with_shapeinf.onnx") + onnx.shape_inference.infer_shapes_path(str(model_path), str(self._tmp_model_path), strict_mode=True) + self.model_with_shape_info = onnx.load(str(self._tmp_model_path)) + + def __del__(self): + if self._tmp_model_path: + self._tmp_model_path.unlink(missing_ok=True)