diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 91a80bf4b2578..12fbb291c3a70 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -286,7 +286,7 @@ "component": { "type": "git", "git": { - "commitHash": "c4f6b8c6bc94ff69048492fb34df0dfaf1983933", + "commitHash": "6f47420213f757831fae65c686aa471749fa8d60", "repositoryUrl": "https://github.com/NVIDIA/cutlass.git" }, "comments": "cutlass" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index e82219a0aff64..5796db03fed7c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -114,9 +114,7 @@ option(onnxruntime_ENABLE_LTO "Enable link time optimization" OFF) option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF) option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF) option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) - -#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But Tensort always required the full version of protobuf. -cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON) +option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir") option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) diff --git a/cmake/deps.txt b/cmake/deps.txt index 33cf5be76a200..49142372ab86e 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -51,7 +51,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8b re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 -cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd +cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/a4f72a314a85732ed67d5aa8d1088d207a7e0e61.zip;f57357ab6d300e207a632d034ebc8aa036a090d9 diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index 8c5d81d638ced..983eecdd88235 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -4,7 +4,6 @@ if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTIO cutlass URL ${DEP_URL_cutlass} URL_HASH SHA1=${DEP_SHA1_cutlass} - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass.patch ) FetchContent_GetProperties(cutlass) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index a62b1b259d109..04efa5c2b4f6d 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -33,6 +33,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/sqnbitgemm.cpp ) if (NOT onnxruntime_ORT_MINIMAL_BUILD) @@ -68,6 +69,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) set(mlas_platform_preprocess_srcs @@ -334,6 +336,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp ) if (NOT APPLE) set(mlas_platform_srcs diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index a9a78668b4810..345ef2b504aa4 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -339,9 +339,6 @@ configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py) if (onnxruntime_ENABLE_TRAINING) - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/python/deprecated/*.py" - ) file(GLOB onnxruntime_python_root_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/*.py" ) @@ -419,10 +416,6 @@ if (onnxruntime_ENABLE_TRAINING) "${ORTTRAINING_SOURCE_DIR}/python/training/onnxblock/optim/*" ) endif() -else() - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/python/training/*.py" - ) endif() if (onnxruntime_BUILD_UNIT_TESTS) @@ -443,6 +436,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx" ) + file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx" + ) endif() file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS @@ -556,6 +552,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/whisper COMMAND ${CMAKE_COMMAND} -E make_directory $/eager_test + COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/conformer COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_ROOT}/__init__.py $/onnxruntime/ @@ -577,9 +574,6 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py $/onnxruntime/capi/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy $ $/onnxruntime/capi/ @@ -711,6 +705,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_testdata_whisper} $/transformers/test_data/models/whisper/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_testdata_conformer} + $/transformers/test_data/models/conformer/ ) endif() @@ -750,9 +747,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/data/ COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/hooks/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_root_srcs} $/onnxruntime/training/ diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index ce27bf756f247..980bd59b22c3f 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -51,6 +51,7 @@ set(contrib_ops_excluded_files "math/gemm_float8.cc" "math/gemm_float8.cu" "math/gemm_float8.h" + "moe/*" "quantization/attention_quantization.cc" "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index bdb0230a8ebd0..a52e941b235b4 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -906,7 +906,7 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") + set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") endif() diff --git a/cmake/patches/cutlass/cutlass.patch b/cmake/patches/cutlass/cutlass.patch deleted file mode 100644 index bda1de8b46916..0000000000000 --- a/cmake/patches/cutlass/cutlass.patch +++ /dev/null @@ -1,92 +0,0 @@ -diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp -index 3790ebd3..cf727d09 100644 ---- a/include/cute/numeric/complex.hpp -+++ b/include/cute/numeric/complex.hpp -@@ -41,10 +41,14 @@ - // With CUDA 11.4, builds show spurious "-Wconversion" warnings - // on line 656 of thrust/detail/type_traits.h. - // These pragmas suppress the warnings. -+#ifdef __GNUC__ - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wconversion" -+#endif - #include -+#ifdef __GNUC__ - #pragma GCC diagnostic pop -+#endif - - #include - -diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h -index 59aec46a..8f2a913a 100644 ---- a/include/cutlass/functional.h -+++ b/include/cutlass/functional.h -@@ -89,7 +89,7 @@ struct multiplies { - } - }; - --#if defined(__CUDA_ARCH__) -+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - /// Partial specializations needed when __CUDA_NO_HALF2_OPERATORS__ is set - template<> - struct plus<__half2> { -@@ -143,12 +143,12 @@ struct multiplies<__half> { - - - // Maximum with nan propogation --// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN -+// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN - template - struct maximum_with_nan_propogation { - CUTLASS_HOST_DEVICE - T operator()(T const &lhs, T const &rhs) const { -- return lhs > rhs or std::isnan(lhs) ? lhs : rhs; -+ return lhs > rhs or isnan(lhs) ? lhs : rhs; - } - }; - -@@ -160,7 +160,7 @@ struct maximum_with_nan_propogation { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); - #else -- res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; -+ res = lhs > rhs or isnan(lhs) ? lhs : rhs; - #endif - return res; - } -@@ -233,7 +233,7 @@ struct negate { - } - }; - --/// Greater equal -+/// Greater equal - template - struct greater_equal { - CUTLASS_HOST_DEVICE -@@ -242,7 +242,7 @@ struct greater_equal { - } - }; - --/// Greater -+/// Greater - template - struct greater { - CUTLASS_HOST_DEVICE -@@ -251,7 +251,7 @@ struct greater { - } - }; - --/// Less equal -+/// Less equal - template - struct less_equal { - CUTLASS_HOST_DEVICE -@@ -260,7 +260,7 @@ struct less_equal { - } - }; - --/// Less -+/// Less - template - struct less { - CUTLASS_HOST_DEVICE diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 169737c9138c3..8565ffbb6c379 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -54,6 +54,7 @@ Do not modify directly.* * com.microsoft.MatMulIntegerToFloat * com.microsoft.MatMulNBits * com.microsoft.MaxpoolWithMask + * com.microsoft.MoE * com.microsoft.MulInteger * com.microsoft.MultiHeadAttention * com.microsoft.MurmurHash3 @@ -2384,7 +2385,7 @@ This version of the operator has been available since version 1 of the 'com.micr Group Query Self/Cross Attention. - Supports different number of heads for q and kv. + Supports different number of heads for q and kv. Only supports causal or local attention. #### Version @@ -2395,6 +2396,8 @@ This version of the operator has been available since version 1 of the 'com.micr
kv_num_heads : int (required)
Number of attention heads for k and v
+
local_window_size : int
+
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
scale : float
@@ -2904,6 +2907,58 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.MoE** + + Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, + GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) + usually uses top 32 experts. + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation_type : string
+
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
k : int
+
Number of top experts to select from expert pool
+
+ +#### Inputs (4 - 6) + +
+
input : T
+
2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+
router_probs : T
+
2D input tensor with shape (num_rows, num_experts)
+
fc1_experts_weights : T
+
3D input tensor with shape (num_experts, hidden_size, inter_size)
+
fc2_experts_weights : T
+
3D input tensor with shape (num_experts, inter_size, hidden_size)
+
fc1_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
+
fc2_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, hidden_size)
+
+ +#### Outputs + +
+
output : T
+
2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float or float16 tensors.
+
+ + ### **com.microsoft.MulInteger** Performs element-wise binary quantized multiplication (with Numpy-style broadcasting support). @@ -4968,7 +5023,7 @@ This version of the operator has been available since version 1 of the 'com.micr
input : T
-
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
@@ -4981,7 +5036,7 @@ This version of the operator has been available since version 1 of the 'com.micr
output : T
-
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
tensor with same shape as input.
#### Type Constraints diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8e546b30aa4cb..26b5ebbdbec36 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -842,6 +842,7 @@ Do not modify directly.* |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/framework/op_node_proto_helper.h b/include/onnxruntime/core/framework/op_node_proto_helper.h index 700e1edc0cb7d..e7ac01947af41 100644 --- a/include/onnxruntime/core/framework/op_node_proto_helper.h +++ b/include/onnxruntime/core/framework/op_node_proto_helper.h @@ -10,20 +10,6 @@ #include "core/common/gsl.h" #endif -#ifdef __has_attribute -#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x) -#else -#define ORT_HAVE_ATTRIBUTE(x) 0 -#endif - -#if ORT_HAVE_ATTRIBUTE(nodiscard) -#define MUST_USE_RESULT [[nodiscard]] -#elif defined(__clang__) && ORT_HAVE_ATTRIBUTE(warn_unused_result) -#define MUST_USE_RESULT __attribute__((warn_unused_result)) -#else -#define MUST_USE_RESULT -#endif - class IMLOpKernel; namespace onnxruntime { @@ -43,14 +29,26 @@ class OpNodeProtoHelper { Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema */ template - MUST_USE_RESULT Status GetAttr(const std::string& name, T* value) const; + Status GetAttr(const std::string& name, T* value) const; + + /** + Get a single attribute + Call this function for a required attribute or when a default value for an optional attribute is specified in the op schema + Throws if an attribute with the specified type doesn't exist + */ + template + [[nodiscard]] T GetAttr(const std::string& name) const { + T value; + ORT_THROW_IF_ERROR(GetAttr(name, &value)); + return value; + } /** Get a single attribute Call this function only when a default value for an optional attribute isn't specified in the op schema */ template - T GetAttrOrDefault(const std::string& name, const T& default_value) const { + [[nodiscard]] T GetAttrOrDefault(const std::string& name, const T& default_value) const { T tmp; return GetAttr(name, &tmp).IsOK() ? tmp : default_value; } @@ -70,7 +68,8 @@ class OpNodeProtoHelper { Call this function only when a default value for an optional attribute isn't specified in the op schema */ template - MUST_USE_RESULT std::vector GetAttrsOrDefault(const std::string& name, const std::vector& default_value = std::vector{}) const { + [[nodiscard]] std::vector GetAttrsOrDefault(const std::string& name, + const std::vector& default_value = {}) const { std::vector tmp; return GetAttrs(name, tmp).IsOK() ? tmp : default_value; } @@ -87,11 +86,12 @@ class OpNodeProtoHelper { /// Attribute data in a span, out parameter /// Status template - MUST_USE_RESULT Status GetAttrsAsSpan(const std::string& name, gsl::span& values) const; + Status GetAttrsAsSpan(const std::string& name, gsl::span& values) const; - MUST_USE_RESULT Status GetAttrs(const std::string& name, TensorShapeVector& out) const; + Status GetAttrs(const std::string& name, TensorShapeVector& out) const; - MUST_USE_RESULT TensorShapeVector GetAttrsOrDefault(const std::string& name, const TensorShapeVector& default_value = TensorShapeVector{}) const { + [[nodiscard]] TensorShapeVector GetAttrsOrDefault(const std::string& name, + const TensorShapeVector& default_value = {}) const { TensorShapeVector tmp; return GetAttrs(name, tmp).IsOK() ? tmp : default_value; } @@ -100,43 +100,43 @@ class OpNodeProtoHelper { Get repeated attributes */ template - MUST_USE_RESULT Status GetAttrs(const std::string& name, std::vector& values) const; + Status GetAttrs(const std::string& name, std::vector& values) const; template - MUST_USE_RESULT Status GetAttrs(const std::string& name, gsl::span values) const; + Status GetAttrs(const std::string& name, gsl::span values) const; - MUST_USE_RESULT Status GetAttrsStringRefs(const std::string& name, - std::vector>& refs) const; + Status GetAttrsStringRefs(const std::string& name, + std::vector>& refs) const; - uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type, - const std::string& name) const noexcept; + [[nodiscard]] uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type, + const std::string& name) const noexcept; - bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type, - const std::string& name) const noexcept; + [[nodiscard]] bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type, + const std::string& name) const noexcept; - uint32_t GetInputCount() const { + [[nodiscard]] uint32_t GetInputCount() const { return gsl::narrow_cast(impl_->getNumInputs()); } - uint32_t GetOutputCount() const { + [[nodiscard]] uint32_t GetOutputCount() const { return gsl::narrow_cast(impl_->getNumOutputs()); } - const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const { + [[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const { return impl_->getInputType(index); } - const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const { + [[nodiscard]] const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const { // Work around lack of a const method from the onnx InferenceContext interface return const_cast(impl_)->getOutputType(index); } // Try to query an attribute, returning nullptr if it doesn't exist - const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const { + [[nodiscard]] const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const { return impl_->getAttribute(name); } - const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const { + [[nodiscard]] const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const { const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name); ORT_ENFORCE(attr != nullptr); return attr; diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index 443710884743a..0c0af16d4e20c 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -399,6 +399,15 @@ struct TensorArray : public ArgBase { using Variadic = TensorArray; +/* +Note: +OrtLiteCustomOp inherits from OrtCustomOp to bridge tween a custom func/struct and ort core. +The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so: +1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierachy. +2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp, + hence memory could still be recycled properly. +Further, OrtCustomOp is a c struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety. +*/ struct OrtLiteCustomOp : public OrtCustomOp { using ConstOptionalFloatTensor = std::optional&>; using OptionalFloatTensor = std::optional>; @@ -774,10 +783,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtLiteCustomOp(const char* op_name, const char* execution_provider, - int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), - execution_provider_(execution_provider), - start_ver_(start_ver), - end_ver_(end_ver) { + ShapeInferFn shape_infer_fn, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + shape_infer_fn_(shape_infer_fn), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -858,8 +870,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + ShapeInferFn shape_infer_fn_ = {}; + int start_ver_ = 1; int end_ver_ = MAX_CUSTOM_OP_END_VER; + + void* compute_fn_ = {}; + void* compute_fn_return_status_ = {}; }; //////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -891,9 +908,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFn compute_fn, ShapeInferFn shape_infer_fn = {}, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), - compute_fn_(compute_fn), - shape_infer_fn_(shape_infer_fn) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_ = reinterpret_cast(compute_fn); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { @@ -905,7 +921,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_ = static_cast(this_)->compute_fn_; + auto me = static_cast(this_); + kernel->compute_fn_ = reinterpret_cast(me->compute_fn_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -931,9 +948,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFnReturnStatus compute_fn_return_status, ShapeInferFn shape_infer_fn = {}, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), - compute_fn_return_status_(compute_fn_return_status), - shape_infer_fn_(shape_infer_fn) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_return_status_ = reinterpret_cast(compute_fn_return_status); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { @@ -945,7 +961,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_return_status_ = static_cast(this_)->compute_fn_return_status_; + auto me = static_cast(this_); + kernel->compute_fn_return_status_ = reinterpret_cast(me->compute_fn_return_status_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -965,10 +982,6 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { }; } } - - ComputeFn compute_fn_ = {}; - ComputeFnReturnStatus compute_fn_return_status_ = {}; - ShapeInferFn shape_infer_fn_ = {}; }; // struct OrtLiteCustomFunc /////////////////////////// OrtLiteCustomStruct /////////////////////////// @@ -1007,7 +1020,7 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { OrtLiteCustomStruct(const char* op_name, const char* execution_provider, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) { SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { diff --git a/js/.eslintrc.js b/js/.eslintrc.js index 3a20429cfe8d2..fd30cb96a5bd0 100644 --- a/js/.eslintrc.js +++ b/js/.eslintrc.js @@ -146,9 +146,13 @@ module.exports = { }, { files: ['web/lib/**/*.ts'], excludedFiles: 'web/lib/wasm/proxy-worker/**/*', - parserOptions: { 'project': 'web/tsconfig.json' },rules: { + parserOptions: { 'project': 'web/tsconfig.json' }, + rules: { 'no-underscore-dangle': 'off', } + }, { + files: ['web/lib/wasm/proxy-worker/**/*.ts'], + parserOptions: { 'project': 'web/lib/wasm/proxy-worker/tsconfig.json' }, }, { files: ['web/lib/onnxjs/**/*.ts'], rules: { // TODO: those rules are useful. should turn on them in future (webgl refactor) diff --git a/js/node/package-lock.json b/js/node/package-lock.json index e8968bafc4a9f..c1cf8af4bb80e 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -22,7 +22,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" } }, "../common": { @@ -97,12 +97,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "node_modules/@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "node_modules/@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -528,9 +522,9 @@ "dev": true }, "node_modules/long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "node_modules/lru-cache": { @@ -663,15 +657,6 @@ "node": "^12.13.0 || ^14.15.0 || >=16.0.0" } }, - "node_modules/onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "dependencies": { - "protobufjs": "^6.11.2" - } - }, "node_modules/onnxruntime-common": { "resolved": "../common", "link": true @@ -690,9 +675,9 @@ } }, "node_modules/protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "hasInstallScript": true, "dependencies": { @@ -706,13 +691,11 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" }, - "bin": { - "pbjs": "bin/pbjs", - "pbts": "bin/pbts" + "engines": { + "node": ">=12.0.0" } }, "node_modules/proxy-from-env": { @@ -789,9 +772,9 @@ ] }, "node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "dependencies": { "lru-cache": "^6.0.0" @@ -1070,12 +1053,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -1413,9 +1390,9 @@ "dev": true }, "long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "lru-cache": { @@ -1523,15 +1500,6 @@ "set-blocking": "^2.0.0" } }, - "onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "requires": { - "protobufjs": "^6.11.2" - } - }, "onnxruntime-common": { "version": "file:../common", "requires": { @@ -1549,9 +1517,9 @@ } }, "protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "requires": { "@protobufjs/aspromise": "^1.1.2", @@ -1564,9 +1532,8 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" } }, "proxy-from-env": { @@ -1619,9 +1586,9 @@ "dev": true }, "semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "requires": { "lru-cache": "^6.0.0" diff --git a/js/node/package.json b/js/node/package.json index 0f8f0e9d2260c..8e591d8f46b9d 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -19,6 +19,7 @@ }, "scripts": { "buildr": "tsc && node ./script/build --config=RelWithDebInfo", + "preprepare": "node -e \"require('node:fs').copyFileSync('./node_modules/long/index.d.ts', './node_modules/long/umd/index.d.ts')\"", "prepare": "tsc --build script test .", "rebuild": "tsc && node ./script/build --rebuild", "rebuildd": "tsc && node ./script/build --rebuild --config=Debug", @@ -39,7 +40,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" }, "main": "dist/index.js", "os": [ diff --git a/js/node/test/ort-schema/protobuf/.gitignore b/js/node/test/ort-schema/protobuf/.gitignore new file mode 100644 index 0000000000000..092bb6c1c9fb4 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/.gitignore @@ -0,0 +1,2 @@ +!onnx.js +!onnx.d.ts diff --git a/js/node/test/ort-schema/protobuf/README.md b/js/node/test/ort-schema/protobuf/README.md new file mode 100644 index 0000000000000..f5f52c602f1ad --- /dev/null +++ b/js/node/test/ort-schema/protobuf/README.md @@ -0,0 +1,21 @@ +# ONNX protobuf + +This directory contains generated protobuf definition for onnx: + +- onnx.js +- onnx.d.ts + +These files are generated from [a fork of onnx-proto](https://github.com/fs-eire/onnx-proto/tree/update-v9). + +The ONNX protobuf uses protobufjs@7.2.4, which depends on long@5.2.3, the version contains 2 bugs: + +- type export does not work with commonjs. described in https://github.com/dcodeIO/long.js/pull/124. added a "postinstall" script to fix. +- in the generated typescript declaration file 'onnx.d.ts', the following line: + ```ts + import Long = require("long"); + ``` + need to be replaced to fix type import error: + ```ts + import Long from "long"; + ``` + this replacement is done and code format is also applied to file 'onnx.d.ts'. diff --git a/js/node/test/ort-schema/protobuf/onnx.d.ts b/js/node/test/ort-schema/protobuf/onnx.d.ts new file mode 100644 index 0000000000000..c60264dca2a8d --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.d.ts @@ -0,0 +1,2627 @@ +import Long from 'long'; +import * as $protobuf from 'protobufjs'; + +/** Namespace onnx. */ +export namespace onnx { + + /** Version enum. */ + enum Version { + _START_VERSION = 0, + IR_VERSION_2017_10_10 = 1, + IR_VERSION_2017_10_30 = 2, + IR_VERSION_2017_11_3 = 3, + IR_VERSION_2019_1_22 = 4, + IR_VERSION_2019_3_18 = 5, + IR_VERSION_2019_9_19 = 6, + IR_VERSION_2020_5_8 = 7, + IR_VERSION_2021_7_30 = 8, + IR_VERSION = 9 + } + + /** Properties of an AttributeProto. */ + interface IAttributeProto { + /** AttributeProto name */ + name?: (string|null); + + /** AttributeProto refAttrName */ + refAttrName?: (string|null); + + /** AttributeProto docString */ + docString?: (string|null); + + /** AttributeProto type */ + type?: (onnx.AttributeProto.AttributeType|null); + + /** AttributeProto f */ + f?: (number|null); + + /** AttributeProto i */ + i?: (number|Long|null); + + /** AttributeProto s */ + s?: (Uint8Array|null); + + /** AttributeProto t */ + t?: (onnx.ITensorProto|null); + + /** AttributeProto g */ + g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor */ + sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp */ + tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats */ + floats?: (number[]|null); + + /** AttributeProto ints */ + ints?: ((number | Long)[]|null); + + /** AttributeProto strings */ + strings?: (Uint8Array[]|null); + + /** AttributeProto tensors */ + tensors?: (onnx.ITensorProto[]|null); + + /** AttributeProto graphs */ + graphs?: (onnx.IGraphProto[]|null); + + /** AttributeProto sparseTensors */ + sparseTensors?: (onnx.ISparseTensorProto[]|null); + + /** AttributeProto typeProtos */ + typeProtos?: (onnx.ITypeProto[]|null); + } + + /** Represents an AttributeProto. */ + class AttributeProto implements IAttributeProto { + /** + * Constructs a new AttributeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IAttributeProto); + + /** AttributeProto name. */ + public name: string; + + /** AttributeProto refAttrName. */ + public refAttrName: string; + + /** AttributeProto docString. */ + public docString: string; + + /** AttributeProto type. */ + public type: onnx.AttributeProto.AttributeType; + + /** AttributeProto f. */ + public f: number; + + /** AttributeProto i. */ + public i: (number|Long); + + /** AttributeProto s. */ + public s: Uint8Array; + + /** AttributeProto t. */ + public t?: (onnx.ITensorProto|null); + + /** AttributeProto g. */ + public g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor. */ + public sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp. */ + public tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats. */ + public floats: number[]; + + /** AttributeProto ints. */ + public ints: (number|Long)[]; + + /** AttributeProto strings. */ + public strings: Uint8Array[]; + + /** AttributeProto tensors. */ + public tensors: onnx.ITensorProto[]; + + /** AttributeProto graphs. */ + public graphs: onnx.IGraphProto[]; + + /** AttributeProto sparseTensors. */ + public sparseTensors: onnx.ISparseTensorProto[]; + + /** AttributeProto typeProtos. */ + public typeProtos: onnx.ITypeProto[]; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns AttributeProto instance + */ + public static create(properties?: onnx.IAttributeProto): onnx.AttributeProto; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} + * messages. + * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link + * onnx.AttributeProto.verify|verify} messages. + * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.AttributeProto; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.AttributeProto; + + /** + * Verifies an AttributeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns AttributeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.AttributeProto; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @param message AttributeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.AttributeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this AttributeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for AttributeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace AttributeProto { + + /** AttributeType enum. */ + enum AttributeType { + UNDEFINED = 0, + FLOAT = 1, + INT = 2, + STRING = 3, + TENSOR = 4, + GRAPH = 5, + SPARSE_TENSOR = 11, + TYPE_PROTO = 13, + FLOATS = 6, + INTS = 7, + STRINGS = 8, + TENSORS = 9, + GRAPHS = 10, + SPARSE_TENSORS = 12, + TYPE_PROTOS = 14 + } + } + + /** Properties of a ValueInfoProto. */ + interface IValueInfoProto { + /** ValueInfoProto name */ + name?: (string|null); + + /** ValueInfoProto type */ + type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString */ + docString?: (string|null); + } + + /** Represents a ValueInfoProto. */ + class ValueInfoProto implements IValueInfoProto { + /** + * Constructs a new ValueInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IValueInfoProto); + + /** ValueInfoProto name. */ + public name: string; + + /** ValueInfoProto type. */ + public type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString. */ + public docString: string; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @param [properties] Properties to set + * @returns ValueInfoProto instance + */ + public static create(properties?: onnx.IValueInfoProto): onnx.ValueInfoProto; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} + * messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link + * onnx.ValueInfoProto.verify|verify} messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ValueInfoProto; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ValueInfoProto; + + /** + * Verifies a ValueInfoProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ValueInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ValueInfoProto; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. + * @param message ValueInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ValueInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ValueInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ValueInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a NodeProto. */ + interface INodeProto { + /** NodeProto input */ + input?: (string[]|null); + + /** NodeProto output */ + output?: (string[]|null); + + /** NodeProto name */ + name?: (string|null); + + /** NodeProto opType */ + opType?: (string|null); + + /** NodeProto domain */ + domain?: (string|null); + + /** NodeProto attribute */ + attribute?: (onnx.IAttributeProto[]|null); + + /** NodeProto docString */ + docString?: (string|null); + } + + /** Represents a NodeProto. */ + class NodeProto implements INodeProto { + /** + * Constructs a new NodeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.INodeProto); + + /** NodeProto input. */ + public input: string[]; + + /** NodeProto output. */ + public output: string[]; + + /** NodeProto name. */ + public name: string; + + /** NodeProto opType. */ + public opType: string; + + /** NodeProto domain. */ + public domain: string; + + /** NodeProto attribute. */ + public attribute: onnx.IAttributeProto[]; + + /** NodeProto docString. */ + public docString: string; + + /** + * Creates a new NodeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns NodeProto instance + */ + public static create(properties?: onnx.INodeProto): onnx.NodeProto; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link + * onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.NodeProto; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.NodeProto; + + /** + * Verifies a NodeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns NodeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.NodeProto; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @param message NodeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.NodeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this NodeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for NodeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TrainingInfoProto. */ + interface ITrainingInfoProto { + /** TrainingInfoProto initialization */ + initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm */ + algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding */ + initializationBinding?: (onnx.IStringStringEntryProto[]|null); + + /** TrainingInfoProto updateBinding */ + updateBinding?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TrainingInfoProto. */ + class TrainingInfoProto implements ITrainingInfoProto { + /** + * Constructs a new TrainingInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITrainingInfoProto); + + /** TrainingInfoProto initialization. */ + public initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm. */ + public algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding. */ + public initializationBinding: onnx.IStringStringEntryProto[]; + + /** TrainingInfoProto updateBinding. */ + public updateBinding: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TrainingInfoProto instance + */ + public static create(properties?: onnx.ITrainingInfoProto): onnx.TrainingInfoProto; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} + * messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link + * onnx.TrainingInfoProto.verify|verify} messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TrainingInfoProto; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TrainingInfoProto; + + /** + * Verifies a TrainingInfoProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TrainingInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TrainingInfoProto; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @param message TrainingInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TrainingInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TrainingInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TrainingInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a ModelProto. */ + interface IModelProto { + /** ModelProto irVersion */ + irVersion?: (number|Long|null); + + /** ModelProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** ModelProto producerName */ + producerName?: (string|null); + + /** ModelProto producerVersion */ + producerVersion?: (string|null); + + /** ModelProto domain */ + domain?: (string|null); + + /** ModelProto modelVersion */ + modelVersion?: (number|Long|null); + + /** ModelProto docString */ + docString?: (string|null); + + /** ModelProto graph */ + graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps */ + metadataProps?: (onnx.IStringStringEntryProto[]|null); + + /** ModelProto trainingInfo */ + trainingInfo?: (onnx.ITrainingInfoProto[]|null); + + /** ModelProto functions */ + functions?: (onnx.IFunctionProto[]|null); + } + + /** Represents a ModelProto. */ + class ModelProto implements IModelProto { + /** + * Constructs a new ModelProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IModelProto); + + /** ModelProto irVersion. */ + public irVersion: (number|Long); + + /** ModelProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** ModelProto producerName. */ + public producerName: string; + + /** ModelProto producerVersion. */ + public producerVersion: string; + + /** ModelProto domain. */ + public domain: string; + + /** ModelProto modelVersion. */ + public modelVersion: (number|Long); + + /** ModelProto docString. */ + public docString: string; + + /** ModelProto graph. */ + public graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps. */ + public metadataProps: onnx.IStringStringEntryProto[]; + + /** ModelProto trainingInfo. */ + public trainingInfo: onnx.ITrainingInfoProto[]; + + /** ModelProto functions. */ + public functions: onnx.IFunctionProto[]; + + /** + * Creates a new ModelProto instance using the specified properties. + * @param [properties] Properties to set + * @returns ModelProto instance + */ + public static create(properties?: onnx.IModelProto): onnx.ModelProto; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link + * onnx.ModelProto.verify|verify} messages. + * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ModelProto; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ModelProto; + + /** + * Verifies a ModelProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ModelProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ModelProto; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @param message ModelProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ModelProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ModelProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ModelProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a StringStringEntryProto. */ + interface IStringStringEntryProto { + /** StringStringEntryProto key */ + key?: (string|null); + + /** StringStringEntryProto value */ + value?: (string|null); + } + + /** Represents a StringStringEntryProto. */ + class StringStringEntryProto implements IStringStringEntryProto { + /** + * Constructs a new StringStringEntryProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IStringStringEntryProto); + + /** StringStringEntryProto key. */ + public key: string; + + /** StringStringEntryProto value. */ + public value: string; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. + * @param [properties] Properties to set + * @returns StringStringEntryProto instance + */ + public static create(properties?: onnx.IStringStringEntryProto): onnx.StringStringEntryProto; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.StringStringEntryProto; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.StringStringEntryProto; + + /** + * Verifies a StringStringEntryProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns StringStringEntryProto + */ + public static fromObject(object: {[k: string]: any}): onnx.StringStringEntryProto; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @param message StringStringEntryProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.StringStringEntryProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this StringStringEntryProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for StringStringEntryProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorAnnotation. */ + interface ITensorAnnotation { + /** TensorAnnotation tensorName */ + tensorName?: (string|null); + + /** TensorAnnotation quantParameterTensorNames */ + quantParameterTensorNames?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TensorAnnotation. */ + class TensorAnnotation implements ITensorAnnotation { + /** + * Constructs a new TensorAnnotation. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorAnnotation); + + /** TensorAnnotation tensorName. */ + public tensorName: string; + + /** TensorAnnotation quantParameterTensorNames. */ + public quantParameterTensorNames: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorAnnotation instance + */ + public static create(properties?: onnx.ITensorAnnotation): onnx.TensorAnnotation; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} + * messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link + * onnx.TensorAnnotation.verify|verify} messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorAnnotation; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorAnnotation; + + /** + * Verifies a TensorAnnotation message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorAnnotation + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorAnnotation; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @param message TensorAnnotation + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorAnnotation, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorAnnotation to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorAnnotation + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a GraphProto. */ + interface IGraphProto { + /** GraphProto node */ + node?: (onnx.INodeProto[]|null); + + /** GraphProto name */ + name?: (string|null); + + /** GraphProto initializer */ + initializer?: (onnx.ITensorProto[]|null); + + /** GraphProto sparseInitializer */ + sparseInitializer?: (onnx.ISparseTensorProto[]|null); + + /** GraphProto docString */ + docString?: (string|null); + + /** GraphProto input */ + input?: (onnx.IValueInfoProto[]|null); + + /** GraphProto output */ + output?: (onnx.IValueInfoProto[]|null); + + /** GraphProto valueInfo */ + valueInfo?: (onnx.IValueInfoProto[]|null); + + /** GraphProto quantizationAnnotation */ + quantizationAnnotation?: (onnx.ITensorAnnotation[]|null); + } + + /** Represents a GraphProto. */ + class GraphProto implements IGraphProto { + /** + * Constructs a new GraphProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IGraphProto); + + /** GraphProto node. */ + public node: onnx.INodeProto[]; + + /** GraphProto name. */ + public name: string; + + /** GraphProto initializer. */ + public initializer: onnx.ITensorProto[]; + + /** GraphProto sparseInitializer. */ + public sparseInitializer: onnx.ISparseTensorProto[]; + + /** GraphProto docString. */ + public docString: string; + + /** GraphProto input. */ + public input: onnx.IValueInfoProto[]; + + /** GraphProto output. */ + public output: onnx.IValueInfoProto[]; + + /** GraphProto valueInfo. */ + public valueInfo: onnx.IValueInfoProto[]; + + /** GraphProto quantizationAnnotation. */ + public quantizationAnnotation: onnx.ITensorAnnotation[]; + + /** + * Creates a new GraphProto instance using the specified properties. + * @param [properties] Properties to set + * @returns GraphProto instance + */ + public static create(properties?: onnx.IGraphProto): onnx.GraphProto; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link + * onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.GraphProto; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.GraphProto; + + /** + * Verifies a GraphProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns GraphProto + */ + public static fromObject(object: {[k: string]: any}): onnx.GraphProto; + + /** + * Creates a plain object from a GraphProto message. Also converts values to other types if specified. + * @param message GraphProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.GraphProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this GraphProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for GraphProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorProto. */ + interface ITensorProto { + /** TensorProto dims */ + dims?: ((number | Long)[]|null); + + /** TensorProto dataType */ + dataType?: (number|null); + + /** TensorProto segment */ + segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData */ + floatData?: (number[]|null); + + /** TensorProto int32Data */ + int32Data?: (number[]|null); + + /** TensorProto stringData */ + stringData?: (Uint8Array[]|null); + + /** TensorProto int64Data */ + int64Data?: ((number | Long)[]|null); + + /** TensorProto name */ + name?: (string|null); + + /** TensorProto docString */ + docString?: (string|null); + + /** TensorProto rawData */ + rawData?: (Uint8Array|null); + + /** TensorProto externalData */ + externalData?: (onnx.IStringStringEntryProto[]|null); + + /** TensorProto dataLocation */ + dataLocation?: (onnx.TensorProto.DataLocation|null); + + /** TensorProto doubleData */ + doubleData?: (number[]|null); + + /** TensorProto uint64Data */ + uint64Data?: ((number | Long)[]|null); + } + + /** Represents a TensorProto. */ + class TensorProto implements ITensorProto { + /** + * Constructs a new TensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorProto); + + /** TensorProto dims. */ + public dims: (number|Long)[]; + + /** TensorProto dataType. */ + public dataType: number; + + /** TensorProto segment. */ + public segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData. */ + public floatData: number[]; + + /** TensorProto int32Data. */ + public int32Data: number[]; + + /** TensorProto stringData. */ + public stringData: Uint8Array[]; + + /** TensorProto int64Data. */ + public int64Data: (number|Long)[]; + + /** TensorProto name. */ + public name: string; + + /** TensorProto docString. */ + public docString: string; + + /** TensorProto rawData. */ + public rawData: Uint8Array; + + /** TensorProto externalData. */ + public externalData: onnx.IStringStringEntryProto[]; + + /** TensorProto dataLocation. */ + public dataLocation: onnx.TensorProto.DataLocation; + + /** TensorProto doubleData. */ + public doubleData: number[]; + + /** TensorProto uint64Data. */ + public uint64Data: (number|Long)[]; + + /** + * Creates a new TensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorProto instance + */ + public static create(properties?: onnx.ITensorProto): onnx.TensorProto; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link + * onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto; + + /** + * Verifies a TensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @param message TensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorProto { + + /** DataType enum. */ + enum DataType { + UNDEFINED = 0, + FLOAT = 1, + UINT8 = 2, + INT8 = 3, + UINT16 = 4, + INT16 = 5, + INT32 = 6, + INT64 = 7, + STRING = 8, + BOOL = 9, + FLOAT16 = 10, + DOUBLE = 11, + UINT32 = 12, + UINT64 = 13, + COMPLEX64 = 14, + COMPLEX128 = 15, + BFLOAT16 = 16, + FLOAT8E4M3FN = 17, + FLOAT8E4M3FNUZ = 18, + FLOAT8E5M2 = 19, + FLOAT8E5M2FNUZ = 20 + } + + /** Properties of a Segment. */ + interface ISegment { + /** Segment begin */ + begin?: (number|Long|null); + + /** Segment end */ + end?: (number|Long|null); + } + + /** Represents a Segment. */ + class Segment implements ISegment { + /** + * Constructs a new Segment. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorProto.ISegment); + + /** Segment begin. */ + public begin: (number|Long); + + /** Segment end. */ + public end: (number|Long); + + /** + * Creates a new Segment instance using the specified properties. + * @param [properties] Properties to set + * @returns Segment instance + */ + public static create(properties?: onnx.TensorProto.ISegment): onnx.TensorProto.Segment; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} + * messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link + * onnx.TensorProto.Segment.verify|verify} messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto.Segment; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto.Segment; + + /** + * Verifies a Segment message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Segment + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto.Segment; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @param message Segment + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto.Segment, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Segment to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Segment + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** DataLocation enum. */ + enum DataLocation { DEFAULT = 0, EXTERNAL = 1 } + } + + /** Properties of a SparseTensorProto. */ + interface ISparseTensorProto { + /** SparseTensorProto values */ + values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices */ + indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims */ + dims?: ((number | Long)[]|null); + } + + /** Represents a SparseTensorProto. */ + class SparseTensorProto implements ISparseTensorProto { + /** + * Constructs a new SparseTensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ISparseTensorProto); + + /** SparseTensorProto values. */ + public values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices. */ + public indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims. */ + public dims: (number|Long)[]; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensorProto instance + */ + public static create(properties?: onnx.ISparseTensorProto): onnx.SparseTensorProto; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} + * messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link + * onnx.SparseTensorProto.verify|verify} messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.SparseTensorProto; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.SparseTensorProto; + + /** + * Verifies a SparseTensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns SparseTensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.SparseTensorProto; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @param message SparseTensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.SparseTensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this SparseTensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorShapeProto. */ + interface ITensorShapeProto { + /** TensorShapeProto dim */ + dim?: (onnx.TensorShapeProto.IDimension[]|null); + } + + /** Represents a TensorShapeProto. */ + class TensorShapeProto implements ITensorShapeProto { + /** + * Constructs a new TensorShapeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorShapeProto); + + /** TensorShapeProto dim. */ + public dim: onnx.TensorShapeProto.IDimension[]; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorShapeProto instance + */ + public static create(properties?: onnx.ITensorShapeProto): onnx.TensorShapeProto; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} + * messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.verify|verify} messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto; + + /** + * Verifies a TensorShapeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorShapeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @param message TensorShapeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorShapeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorShapeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorShapeProto { + + /** Properties of a Dimension. */ + interface IDimension { + /** Dimension dimValue */ + dimValue?: (number|Long|null); + + /** Dimension dimParam */ + dimParam?: (string|null); + + /** Dimension denotation */ + denotation?: (string|null); + } + + /** Represents a Dimension. */ + class Dimension implements IDimension { + /** + * Constructs a new Dimension. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorShapeProto.IDimension); + + /** Dimension dimValue. */ + public dimValue?: (number|Long|null); + + /** Dimension dimParam. */ + public dimParam?: (string|null); + + /** Dimension denotation. */ + public denotation: string; + + /** Dimension value. */ + public value?: ('dimValue'|'dimParam'); + + /** + * Creates a new Dimension instance using the specified properties. + * @param [properties] Properties to set + * @returns Dimension instance + */ + public static create(properties?: onnx.TensorShapeProto.IDimension): onnx.TensorShapeProto.Dimension; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): + $protobuf.Writer; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto.Dimension; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto.Dimension; + + /** + * Verifies a Dimension message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Dimension + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto.Dimension; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @param message Dimension + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto.Dimension, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Dimension to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Dimension + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of a TypeProto. */ + interface ITypeProto { + /** TypeProto tensorType */ + tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType */ + sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType */ + mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType */ + optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType */ + sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation */ + denotation?: (string|null); + } + + /** Represents a TypeProto. */ + class TypeProto implements ITypeProto { + /** + * Constructs a new TypeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITypeProto); + + /** TypeProto tensorType. */ + public tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType. */ + public sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType. */ + public mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType. */ + public optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType. */ + public sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation. */ + public denotation: string; + + /** TypeProto value. */ + public value?: ('tensorType'|'sequenceType'|'mapType'|'optionalType'|'sparseTensorType'); + + /** + * Creates a new TypeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TypeProto instance + */ + public static create(properties?: onnx.ITypeProto): onnx.TypeProto; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link + * onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto; + + /** + * Verifies a TypeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TypeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. + * @param message TypeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TypeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TypeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TypeProto { + + /** Properties of a Tensor. */ + interface ITensor { + /** Tensor elemType */ + elemType?: (number|null); + + /** Tensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a Tensor. */ + class Tensor implements ITensor { + /** + * Constructs a new Tensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ITensor); + + /** Tensor elemType. */ + public elemType: number; + + /** Tensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new Tensor instance using the specified properties. + * @param [properties] Properties to set + * @returns Tensor instance + */ + public static create(properties?: onnx.TypeProto.ITensor): onnx.TypeProto.Tensor; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Tensor; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Tensor; + + /** + * Verifies a Tensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Tensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Tensor; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @param message Tensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Tensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Tensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Tensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Sequence. */ + interface ISequence { + /** Sequence elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents a Sequence. */ + class Sequence implements ISequence { + /** + * Constructs a new Sequence. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISequence); + + /** Sequence elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Sequence instance using the specified properties. + * @param [properties] Properties to set + * @returns Sequence instance + */ + public static create(properties?: onnx.TypeProto.ISequence): onnx.TypeProto.Sequence; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} + * messages. + * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Sequence.verify|verify} messages. + * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Sequence; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Sequence; + + /** + * Verifies a Sequence message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Sequence + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Sequence; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @param message Sequence + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Sequence, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Sequence to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Sequence + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Map. */ + interface IMap { + /** Map keyType */ + keyType?: (number|null); + + /** Map valueType */ + valueType?: (onnx.ITypeProto|null); + } + + /** Represents a Map. */ + class Map implements IMap { + /** + * Constructs a new Map. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IMap); + + /** Map keyType. */ + public keyType: number; + + /** Map valueType. */ + public valueType?: (onnx.ITypeProto|null); + + /** + * Creates a new Map instance using the specified properties. + * @param [properties] Properties to set + * @returns Map instance + */ + public static create(properties?: onnx.TypeProto.IMap): onnx.TypeProto.Map; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Map.verify|verify} messages. + * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Map message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Map; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Map; + + /** + * Verifies a Map message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Map + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Map; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @param message Map + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Map, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this Map to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Map + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of an Optional. */ + interface IOptional { + /** Optional elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents an Optional. */ + class Optional implements IOptional { + /** + * Constructs a new Optional. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IOptional); + + /** Optional elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Optional instance using the specified properties. + * @param [properties] Properties to set + * @returns Optional instance + */ + public static create(properties?: onnx.TypeProto.IOptional): onnx.TypeProto.Optional; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} + * messages. + * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Optional.verify|verify} messages. + * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Optional; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Optional; + + /** + * Verifies an Optional message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Optional + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Optional; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. + * @param message Optional + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Optional, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Optional to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Optional + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a SparseTensor. */ + interface ISparseTensor { + /** SparseTensor elemType */ + elemType?: (number|null); + + /** SparseTensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a SparseTensor. */ + class SparseTensor implements ISparseTensor { + /** + * Constructs a new SparseTensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISparseTensor); + + /** SparseTensor elemType. */ + public elemType: number; + + /** SparseTensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new SparseTensor instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensor instance + */ + public static create(properties?: onnx.TypeProto.ISparseTensor): onnx.TypeProto.SparseTensor; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.SparseTensor; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.SparseTensor; + + /** + * Verifies a SparseTensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns SparseTensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.SparseTensor; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @param message SparseTensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.SparseTensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this SparseTensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of an OperatorSetIdProto. */ + interface IOperatorSetIdProto { + /** OperatorSetIdProto domain */ + domain?: (string|null); + + /** OperatorSetIdProto version */ + version?: (number|Long|null); + } + + /** Represents an OperatorSetIdProto. */ + class OperatorSetIdProto implements IOperatorSetIdProto { + /** + * Constructs a new OperatorSetIdProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IOperatorSetIdProto); + + /** OperatorSetIdProto domain. */ + public domain: string; + + /** OperatorSetIdProto version. */ + public version: (number|Long); + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @param [properties] Properties to set + * @returns OperatorSetIdProto instance + */ + public static create(properties?: onnx.IOperatorSetIdProto): onnx.OperatorSetIdProto; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. + * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. + * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.OperatorSetIdProto; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.OperatorSetIdProto; + + /** + * Verifies an OperatorSetIdProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns OperatorSetIdProto + */ + public static fromObject(object: {[k: string]: any}): onnx.OperatorSetIdProto; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @param message OperatorSetIdProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.OperatorSetIdProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this OperatorSetIdProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for OperatorSetIdProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** OperatorStatus enum. */ + enum OperatorStatus { EXPERIMENTAL = 0, STABLE = 1 } + + /** Properties of a FunctionProto. */ + interface IFunctionProto { + /** FunctionProto name */ + name?: (string|null); + + /** FunctionProto input */ + input?: (string[]|null); + + /** FunctionProto output */ + output?: (string[]|null); + + /** FunctionProto attribute */ + attribute?: (string[]|null); + + /** FunctionProto attributeProto */ + attributeProto?: (onnx.IAttributeProto[]|null); + + /** FunctionProto node */ + node?: (onnx.INodeProto[]|null); + + /** FunctionProto docString */ + docString?: (string|null); + + /** FunctionProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** FunctionProto domain */ + domain?: (string|null); + } + + /** Represents a FunctionProto. */ + class FunctionProto implements IFunctionProto { + /** + * Constructs a new FunctionProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IFunctionProto); + + /** FunctionProto name. */ + public name: string; + + /** FunctionProto input. */ + public input: string[]; + + /** FunctionProto output. */ + public output: string[]; + + /** FunctionProto attribute. */ + public attribute: string[]; + + /** FunctionProto attributeProto. */ + public attributeProto: onnx.IAttributeProto[]; + + /** FunctionProto node. */ + public node: onnx.INodeProto[]; + + /** FunctionProto docString. */ + public docString: string; + + /** FunctionProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** FunctionProto domain. */ + public domain: string; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @param [properties] Properties to set + * @returns FunctionProto instance + */ + public static create(properties?: onnx.IFunctionProto): onnx.FunctionProto; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} + * messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link + * onnx.FunctionProto.verify|verify} messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.FunctionProto; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.FunctionProto; + + /** + * Verifies a FunctionProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns FunctionProto + */ + public static fromObject(object: {[k: string]: any}): onnx.FunctionProto; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. + * @param message FunctionProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.FunctionProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this FunctionProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for FunctionProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } +} diff --git a/js/node/test/ort-schema/protobuf/onnx.js b/js/node/test/ort-schema/protobuf/onnx.js new file mode 100644 index 0000000000000..681855132d4e8 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.js @@ -0,0 +1,7658 @@ +/*eslint-disable block-scoped-var, id-length, no-control-regex, no-magic-numbers, no-prototype-builtins, no-redeclare, no-shadow, no-var, sort-vars*/ +"use strict"; + +var $protobuf = require("protobufjs/minimal"); + +// Common aliases +var $Reader = $protobuf.Reader, $Writer = $protobuf.Writer, $util = $protobuf.util; + +// Exported root namespace +var $root = $protobuf.roots["default"] || ($protobuf.roots["default"] = {}); + +$root.onnx = (function() { + + /** + * Namespace onnx. + * @exports onnx + * @namespace + */ + var onnx = {}; + + /** + * Version enum. + * @name onnx.Version + * @enum {number} + * @property {number} _START_VERSION=0 _START_VERSION value + * @property {number} IR_VERSION_2017_10_10=1 IR_VERSION_2017_10_10 value + * @property {number} IR_VERSION_2017_10_30=2 IR_VERSION_2017_10_30 value + * @property {number} IR_VERSION_2017_11_3=3 IR_VERSION_2017_11_3 value + * @property {number} IR_VERSION_2019_1_22=4 IR_VERSION_2019_1_22 value + * @property {number} IR_VERSION_2019_3_18=5 IR_VERSION_2019_3_18 value + * @property {number} IR_VERSION_2019_9_19=6 IR_VERSION_2019_9_19 value + * @property {number} IR_VERSION_2020_5_8=7 IR_VERSION_2020_5_8 value + * @property {number} IR_VERSION_2021_7_30=8 IR_VERSION_2021_7_30 value + * @property {number} IR_VERSION=9 IR_VERSION value + */ + onnx.Version = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "_START_VERSION"] = 0; + values[valuesById[1] = "IR_VERSION_2017_10_10"] = 1; + values[valuesById[2] = "IR_VERSION_2017_10_30"] = 2; + values[valuesById[3] = "IR_VERSION_2017_11_3"] = 3; + values[valuesById[4] = "IR_VERSION_2019_1_22"] = 4; + values[valuesById[5] = "IR_VERSION_2019_3_18"] = 5; + values[valuesById[6] = "IR_VERSION_2019_9_19"] = 6; + values[valuesById[7] = "IR_VERSION_2020_5_8"] = 7; + values[valuesById[8] = "IR_VERSION_2021_7_30"] = 8; + values[valuesById[9] = "IR_VERSION"] = 9; + return values; + })(); + + onnx.AttributeProto = (function() { + + /** + * Properties of an AttributeProto. + * @memberof onnx + * @interface IAttributeProto + * @property {string|null} [name] AttributeProto name + * @property {string|null} [refAttrName] AttributeProto refAttrName + * @property {string|null} [docString] AttributeProto docString + * @property {onnx.AttributeProto.AttributeType|null} [type] AttributeProto type + * @property {number|null} [f] AttributeProto f + * @property {number|Long|null} [i] AttributeProto i + * @property {Uint8Array|null} [s] AttributeProto s + * @property {onnx.ITensorProto|null} [t] AttributeProto t + * @property {onnx.IGraphProto|null} [g] AttributeProto g + * @property {onnx.ISparseTensorProto|null} [sparseTensor] AttributeProto sparseTensor + * @property {onnx.ITypeProto|null} [tp] AttributeProto tp + * @property {Array.|null} [floats] AttributeProto floats + * @property {Array.|null} [ints] AttributeProto ints + * @property {Array.|null} [strings] AttributeProto strings + * @property {Array.|null} [tensors] AttributeProto tensors + * @property {Array.|null} [graphs] AttributeProto graphs + * @property {Array.|null} [sparseTensors] AttributeProto sparseTensors + * @property {Array.|null} [typeProtos] AttributeProto typeProtos + */ + + /** + * Constructs a new AttributeProto. + * @memberof onnx + * @classdesc Represents an AttributeProto. + * @implements IAttributeProto + * @constructor + * @param {onnx.IAttributeProto=} [properties] Properties to set + */ + function AttributeProto(properties) { + this.floats = []; + this.ints = []; + this.strings = []; + this.tensors = []; + this.graphs = []; + this.sparseTensors = []; + this.typeProtos = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * AttributeProto name. + * @member {string} name + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.name = ""; + + /** + * AttributeProto refAttrName. + * @member {string} refAttrName + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.refAttrName = ""; + + /** + * AttributeProto docString. + * @member {string} docString + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.docString = ""; + + /** + * AttributeProto type. + * @member {onnx.AttributeProto.AttributeType} type + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.type = 0; + + /** + * AttributeProto f. + * @member {number} f + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.f = 0; + + /** + * AttributeProto i. + * @member {number|Long} i + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * AttributeProto s. + * @member {Uint8Array} s + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.s = $util.newBuffer([]); + + /** + * AttributeProto t. + * @member {onnx.ITensorProto|null|undefined} t + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.t = null; + + /** + * AttributeProto g. + * @member {onnx.IGraphProto|null|undefined} g + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.g = null; + + /** + * AttributeProto sparseTensor. + * @member {onnx.ISparseTensorProto|null|undefined} sparseTensor + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensor = null; + + /** + * AttributeProto tp. + * @member {onnx.ITypeProto|null|undefined} tp + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tp = null; + + /** + * AttributeProto floats. + * @member {Array.} floats + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.floats = $util.emptyArray; + + /** + * AttributeProto ints. + * @member {Array.} ints + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.ints = $util.emptyArray; + + /** + * AttributeProto strings. + * @member {Array.} strings + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.strings = $util.emptyArray; + + /** + * AttributeProto tensors. + * @member {Array.} tensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tensors = $util.emptyArray; + + /** + * AttributeProto graphs. + * @member {Array.} graphs + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.graphs = $util.emptyArray; + + /** + * AttributeProto sparseTensors. + * @member {Array.} sparseTensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensors = $util.emptyArray; + + /** + * AttributeProto typeProtos. + * @member {Array.} typeProtos + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.typeProtos = $util.emptyArray; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @function create + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto=} [properties] Properties to set + * @returns {onnx.AttributeProto} AttributeProto instance + */ + AttributeProto.create = function create(properties) { + return new AttributeProto(properties); + }; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. + * @function encode + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.f != null && Object.hasOwnProperty.call(message, "f")) + writer.uint32(/* id 2, wireType 5 =*/21).float(message.f); + if (message.i != null && Object.hasOwnProperty.call(message, "i")) + writer.uint32(/* id 3, wireType 0 =*/24).int64(message.i); + if (message.s != null && Object.hasOwnProperty.call(message, "s")) + writer.uint32(/* id 4, wireType 2 =*/34).bytes(message.s); + if (message.t != null && Object.hasOwnProperty.call(message, "t")) + $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.g != null && Object.hasOwnProperty.call(message, "g")) + $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/50).fork()).ldelim(); + if (message.floats != null && message.floats.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.floats.length; ++i) + writer.float(message.floats[i]); + writer.ldelim(); + } + if (message.ints != null && message.ints.length) { + writer.uint32(/* id 8, wireType 2 =*/66).fork(); + for (var i = 0; i < message.ints.length; ++i) + writer.int64(message.ints[i]); + writer.ldelim(); + } + if (message.strings != null && message.strings.length) + for (var i = 0; i < message.strings.length; ++i) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.strings[i]); + if (message.tensors != null && message.tensors.length) + for (var i = 0; i < message.tensors.length; ++i) + $root.onnx.TensorProto.encode(message.tensors[i], writer.uint32(/* id 10, wireType 2 =*/82).fork()).ldelim(); + if (message.graphs != null && message.graphs.length) + for (var i = 0; i < message.graphs.length; ++i) + $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 13, wireType 2 =*/106).string(message.docString); + if (message.tp != null && Object.hasOwnProperty.call(message, "tp")) + $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.typeProtos != null && message.typeProtos.length) + for (var i = 0; i < message.typeProtos.length; ++i) + $root.onnx.TypeProto.encode(message.typeProtos[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + writer.uint32(/* id 20, wireType 0 =*/160).int32(message.type); + if (message.refAttrName != null && Object.hasOwnProperty.call(message, "refAttrName")) + writer.uint32(/* id 21, wireType 2 =*/170).string(message.refAttrName); + if (message.sparseTensor != null && Object.hasOwnProperty.call(message, "sparseTensor")) + $root.onnx.SparseTensorProto.encode(message.sparseTensor, writer.uint32(/* id 22, wireType 2 =*/178).fork()).ldelim(); + if (message.sparseTensors != null && message.sparseTensors.length) + for (var i = 0; i < message.sparseTensors.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseTensors[i], writer.uint32(/* id 23, wireType 2 =*/186).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.AttributeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 21: { + message.refAttrName = reader.string(); + break; + } + case 13: { + message.docString = reader.string(); + break; + } + case 20: { + message.type = reader.int32(); + break; + } + case 2: { + message.f = reader.float(); + break; + } + case 3: { + message.i = reader.int64(); + break; + } + case 4: { + message.s = reader.bytes(); + break; + } + case 5: { + message.t = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 6: { + message.g = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 22: { + message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); + break; + } + case 14: { + message.tp = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 7: { + if (!(message.floats && message.floats.length)) + message.floats = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floats.push(reader.float()); + } else + message.floats.push(reader.float()); + break; + } + case 8: { + if (!(message.ints && message.ints.length)) + message.ints = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.ints.push(reader.int64()); + } else + message.ints.push(reader.int64()); + break; + } + case 9: { + if (!(message.strings && message.strings.length)) + message.strings = []; + message.strings.push(reader.bytes()); + break; + } + case 10: { + if (!(message.tensors && message.tensors.length)) + message.tensors = []; + message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 11: { + if (!(message.graphs && message.graphs.length)) + message.graphs = []; + message.graphs.push($root.onnx.GraphProto.decode(reader, reader.uint32())); + break; + } + case 23: { + if (!(message.sparseTensors && message.sparseTensors.length)) + message.sparseTensors = []; + message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.typeProtos && message.typeProtos.length)) + message.typeProtos = []; + message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an AttributeProto message. + * @function verify + * @memberof onnx.AttributeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + AttributeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + if (!$util.isString(message.refAttrName)) + return "refAttrName: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.type != null && message.hasOwnProperty("type")) + switch (message.type) { + default: + return "type: enum value expected"; + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 11: + case 13: + case 6: + case 7: + case 8: + case 9: + case 10: + case 12: + case 14: + break; + } + if (message.f != null && message.hasOwnProperty("f")) + if (typeof message.f !== "number") + return "f: number expected"; + if (message.i != null && message.hasOwnProperty("i")) + if (!$util.isInteger(message.i) && !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high))) + return "i: integer|Long expected"; + if (message.s != null && message.hasOwnProperty("s")) + if (!(message.s && typeof message.s.length === "number" || $util.isString(message.s))) + return "s: buffer expected"; + if (message.t != null && message.hasOwnProperty("t")) { + var error = $root.onnx.TensorProto.verify(message.t); + if (error) + return "t." + error; + } + if (message.g != null && message.hasOwnProperty("g")) { + var error = $root.onnx.GraphProto.verify(message.g); + if (error) + return "g." + error; + } + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor); + if (error) + return "sparseTensor." + error; + } + if (message.tp != null && message.hasOwnProperty("tp")) { + var error = $root.onnx.TypeProto.verify(message.tp); + if (error) + return "tp." + error; + } + if (message.floats != null && message.hasOwnProperty("floats")) { + if (!Array.isArray(message.floats)) + return "floats: array expected"; + for (var i = 0; i < message.floats.length; ++i) + if (typeof message.floats[i] !== "number") + return "floats: number[] expected"; + } + if (message.ints != null && message.hasOwnProperty("ints")) { + if (!Array.isArray(message.ints)) + return "ints: array expected"; + for (var i = 0; i < message.ints.length; ++i) + if (!$util.isInteger(message.ints[i]) && !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high))) + return "ints: integer|Long[] expected"; + } + if (message.strings != null && message.hasOwnProperty("strings")) { + if (!Array.isArray(message.strings)) + return "strings: array expected"; + for (var i = 0; i < message.strings.length; ++i) + if (!(message.strings[i] && typeof message.strings[i].length === "number" || $util.isString(message.strings[i]))) + return "strings: buffer[] expected"; + } + if (message.tensors != null && message.hasOwnProperty("tensors")) { + if (!Array.isArray(message.tensors)) + return "tensors: array expected"; + for (var i = 0; i < message.tensors.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.tensors[i]); + if (error) + return "tensors." + error; + } + } + if (message.graphs != null && message.hasOwnProperty("graphs")) { + if (!Array.isArray(message.graphs)) + return "graphs: array expected"; + for (var i = 0; i < message.graphs.length; ++i) { + var error = $root.onnx.GraphProto.verify(message.graphs[i]); + if (error) + return "graphs." + error; + } + } + if (message.sparseTensors != null && message.hasOwnProperty("sparseTensors")) { + if (!Array.isArray(message.sparseTensors)) + return "sparseTensors: array expected"; + for (var i = 0; i < message.sparseTensors.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); + if (error) + return "sparseTensors." + error; + } + } + if (message.typeProtos != null && message.hasOwnProperty("typeProtos")) { + if (!Array.isArray(message.typeProtos)) + return "typeProtos: array expected"; + for (var i = 0; i < message.typeProtos.length; ++i) { + var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); + if (error) + return "typeProtos." + error; + } + } + return null; + }; + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.AttributeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.AttributeProto} AttributeProto + */ + AttributeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.AttributeProto) + return object; + var message = new $root.onnx.AttributeProto(); + if (object.name != null) + message.name = String(object.name); + if (object.refAttrName != null) + message.refAttrName = String(object.refAttrName); + if (object.docString != null) + message.docString = String(object.docString); + switch (object.type) { + default: + if (typeof object.type === "number") { + message.type = object.type; + break; + } + break; + case "UNDEFINED": + case 0: + message.type = 0; + break; + case "FLOAT": + case 1: + message.type = 1; + break; + case "INT": + case 2: + message.type = 2; + break; + case "STRING": + case 3: + message.type = 3; + break; + case "TENSOR": + case 4: + message.type = 4; + break; + case "GRAPH": + case 5: + message.type = 5; + break; + case "SPARSE_TENSOR": + case 11: + message.type = 11; + break; + case "TYPE_PROTO": + case 13: + message.type = 13; + break; + case "FLOATS": + case 6: + message.type = 6; + break; + case "INTS": + case 7: + message.type = 7; + break; + case "STRINGS": + case 8: + message.type = 8; + break; + case "TENSORS": + case 9: + message.type = 9; + break; + case "GRAPHS": + case 10: + message.type = 10; + break; + case "SPARSE_TENSORS": + case 12: + message.type = 12; + break; + case "TYPE_PROTOS": + case 14: + message.type = 14; + break; + } + if (object.f != null) + message.f = Number(object.f); + if (object.i != null) + if ($util.Long) + (message.i = $util.Long.fromValue(object.i)).unsigned = false; + else if (typeof object.i === "string") + message.i = parseInt(object.i, 10); + else if (typeof object.i === "number") + message.i = object.i; + else if (typeof object.i === "object") + message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber(); + if (object.s != null) + if (typeof object.s === "string") + $util.base64.decode(object.s, message.s = $util.newBuffer($util.base64.length(object.s)), 0); + else if (object.s.length >= 0) + message.s = object.s; + if (object.t != null) { + if (typeof object.t !== "object") + throw TypeError(".onnx.AttributeProto.t: object expected"); + message.t = $root.onnx.TensorProto.fromObject(object.t); + } + if (object.g != null) { + if (typeof object.g !== "object") + throw TypeError(".onnx.AttributeProto.g: object expected"); + message.g = $root.onnx.GraphProto.fromObject(object.g); + } + if (object.sparseTensor != null) { + if (typeof object.sparseTensor !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensor: object expected"); + message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor); + } + if (object.tp != null) { + if (typeof object.tp !== "object") + throw TypeError(".onnx.AttributeProto.tp: object expected"); + message.tp = $root.onnx.TypeProto.fromObject(object.tp); + } + if (object.floats) { + if (!Array.isArray(object.floats)) + throw TypeError(".onnx.AttributeProto.floats: array expected"); + message.floats = []; + for (var i = 0; i < object.floats.length; ++i) + message.floats[i] = Number(object.floats[i]); + } + if (object.ints) { + if (!Array.isArray(object.ints)) + throw TypeError(".onnx.AttributeProto.ints: array expected"); + message.ints = []; + for (var i = 0; i < object.ints.length; ++i) + if ($util.Long) + (message.ints[i] = $util.Long.fromValue(object.ints[i])).unsigned = false; + else if (typeof object.ints[i] === "string") + message.ints[i] = parseInt(object.ints[i], 10); + else if (typeof object.ints[i] === "number") + message.ints[i] = object.ints[i]; + else if (typeof object.ints[i] === "object") + message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber(); + } + if (object.strings) { + if (!Array.isArray(object.strings)) + throw TypeError(".onnx.AttributeProto.strings: array expected"); + message.strings = []; + for (var i = 0; i < object.strings.length; ++i) + if (typeof object.strings[i] === "string") + $util.base64.decode(object.strings[i], message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i])), 0); + else if (object.strings[i].length >= 0) + message.strings[i] = object.strings[i]; + } + if (object.tensors) { + if (!Array.isArray(object.tensors)) + throw TypeError(".onnx.AttributeProto.tensors: array expected"); + message.tensors = []; + for (var i = 0; i < object.tensors.length; ++i) { + if (typeof object.tensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.tensors: object expected"); + message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]); + } + } + if (object.graphs) { + if (!Array.isArray(object.graphs)) + throw TypeError(".onnx.AttributeProto.graphs: array expected"); + message.graphs = []; + for (var i = 0; i < object.graphs.length; ++i) { + if (typeof object.graphs[i] !== "object") + throw TypeError(".onnx.AttributeProto.graphs: object expected"); + message.graphs[i] = $root.onnx.GraphProto.fromObject(object.graphs[i]); + } + } + if (object.sparseTensors) { + if (!Array.isArray(object.sparseTensors)) + throw TypeError(".onnx.AttributeProto.sparseTensors: array expected"); + message.sparseTensors = []; + for (var i = 0; i < object.sparseTensors.length; ++i) { + if (typeof object.sparseTensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensors: object expected"); + message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]); + } + } + if (object.typeProtos) { + if (!Array.isArray(object.typeProtos)) + throw TypeError(".onnx.AttributeProto.typeProtos: array expected"); + message.typeProtos = []; + for (var i = 0; i < object.typeProtos.length; ++i) { + if (typeof object.typeProtos[i] !== "object") + throw TypeError(".onnx.AttributeProto.typeProtos: object expected"); + message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]); + } + } + return message; + }; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.AttributeProto + * @static + * @param {onnx.AttributeProto} message AttributeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + AttributeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.floats = []; + object.ints = []; + object.strings = []; + object.tensors = []; + object.graphs = []; + object.typeProtos = []; + object.sparseTensors = []; + } + if (options.defaults) { + object.name = ""; + object.f = 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.i = options.longs === String ? "0" : 0; + if (options.bytes === String) + object.s = ""; + else { + object.s = []; + if (options.bytes !== Array) + object.s = $util.newBuffer(object.s); + } + object.t = null; + object.g = null; + object.docString = ""; + object.tp = null; + object.type = options.enums === String ? "UNDEFINED" : 0; + object.refAttrName = ""; + object.sparseTensor = null; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.f != null && message.hasOwnProperty("f")) + object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; + if (message.i != null && message.hasOwnProperty("i")) + if (typeof message.i === "number") + object.i = options.longs === String ? String(message.i) : message.i; + else + object.i = options.longs === String ? $util.Long.prototype.toString.call(message.i) : options.longs === Number ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() : message.i; + if (message.s != null && message.hasOwnProperty("s")) + object.s = options.bytes === String ? $util.base64.encode(message.s, 0, message.s.length) : options.bytes === Array ? Array.prototype.slice.call(message.s) : message.s; + if (message.t != null && message.hasOwnProperty("t")) + object.t = $root.onnx.TensorProto.toObject(message.t, options); + if (message.g != null && message.hasOwnProperty("g")) + object.g = $root.onnx.GraphProto.toObject(message.g, options); + if (message.floats && message.floats.length) { + object.floats = []; + for (var j = 0; j < message.floats.length; ++j) + object.floats[j] = options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; + } + if (message.ints && message.ints.length) { + object.ints = []; + for (var j = 0; j < message.ints.length; ++j) + if (typeof message.ints[j] === "number") + object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; + else + object.ints[j] = options.longs === String ? $util.Long.prototype.toString.call(message.ints[j]) : options.longs === Number ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() : message.ints[j]; + } + if (message.strings && message.strings.length) { + object.strings = []; + for (var j = 0; j < message.strings.length; ++j) + object.strings[j] = options.bytes === String ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.strings[j]) : message.strings[j]; + } + if (message.tensors && message.tensors.length) { + object.tensors = []; + for (var j = 0; j < message.tensors.length; ++j) + object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); + } + if (message.graphs && message.graphs.length) { + object.graphs = []; + for (var j = 0; j < message.graphs.length; ++j) + object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.tp != null && message.hasOwnProperty("tp")) + object.tp = $root.onnx.TypeProto.toObject(message.tp, options); + if (message.typeProtos && message.typeProtos.length) { + object.typeProtos = []; + for (var j = 0; j < message.typeProtos.length; ++j) + object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); + } + if (message.type != null && message.hasOwnProperty("type")) + object.type = options.enums === String ? $root.onnx.AttributeProto.AttributeType[message.type] === undefined ? message.type : $root.onnx.AttributeProto.AttributeType[message.type] : message.type; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + object.refAttrName = message.refAttrName; + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) + object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); + if (message.sparseTensors && message.sparseTensors.length) { + object.sparseTensors = []; + for (var j = 0; j < message.sparseTensors.length; ++j) + object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); + } + return object; + }; + + /** + * Converts this AttributeProto to JSON. + * @function toJSON + * @memberof onnx.AttributeProto + * @instance + * @returns {Object.} JSON object + */ + AttributeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for AttributeProto + * @function getTypeUrl + * @memberof onnx.AttributeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.AttributeProto"; + }; + + /** + * AttributeType enum. + * @name onnx.AttributeProto.AttributeType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} INT=2 INT value + * @property {number} STRING=3 STRING value + * @property {number} TENSOR=4 TENSOR value + * @property {number} GRAPH=5 GRAPH value + * @property {number} SPARSE_TENSOR=11 SPARSE_TENSOR value + * @property {number} TYPE_PROTO=13 TYPE_PROTO value + * @property {number} FLOATS=6 FLOATS value + * @property {number} INTS=7 INTS value + * @property {number} STRINGS=8 STRINGS value + * @property {number} TENSORS=9 TENSORS value + * @property {number} GRAPHS=10 GRAPHS value + * @property {number} SPARSE_TENSORS=12 SPARSE_TENSORS value + * @property {number} TYPE_PROTOS=14 TYPE_PROTOS value + */ + AttributeProto.AttributeType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "INT"] = 2; + values[valuesById[3] = "STRING"] = 3; + values[valuesById[4] = "TENSOR"] = 4; + values[valuesById[5] = "GRAPH"] = 5; + values[valuesById[11] = "SPARSE_TENSOR"] = 11; + values[valuesById[13] = "TYPE_PROTO"] = 13; + values[valuesById[6] = "FLOATS"] = 6; + values[valuesById[7] = "INTS"] = 7; + values[valuesById[8] = "STRINGS"] = 8; + values[valuesById[9] = "TENSORS"] = 9; + values[valuesById[10] = "GRAPHS"] = 10; + values[valuesById[12] = "SPARSE_TENSORS"] = 12; + values[valuesById[14] = "TYPE_PROTOS"] = 14; + return values; + })(); + + return AttributeProto; + })(); + + onnx.ValueInfoProto = (function() { + + /** + * Properties of a ValueInfoProto. + * @memberof onnx + * @interface IValueInfoProto + * @property {string|null} [name] ValueInfoProto name + * @property {onnx.ITypeProto|null} [type] ValueInfoProto type + * @property {string|null} [docString] ValueInfoProto docString + */ + + /** + * Constructs a new ValueInfoProto. + * @memberof onnx + * @classdesc Represents a ValueInfoProto. + * @implements IValueInfoProto + * @constructor + * @param {onnx.IValueInfoProto=} [properties] Properties to set + */ + function ValueInfoProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ValueInfoProto name. + * @member {string} name + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.name = ""; + + /** + * ValueInfoProto type. + * @member {onnx.ITypeProto|null|undefined} type + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.type = null; + + /** + * ValueInfoProto docString. + * @member {string} docString + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.docString = ""; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @function create + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto=} [properties] Properties to set + * @returns {onnx.ValueInfoProto} ValueInfoProto instance + */ + ValueInfoProto.create = function create(properties) { + return new ValueInfoProto(properties); + }; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encode + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.docString); + return writer; + }; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ValueInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 2: { + message.type = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 3: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ValueInfoProto message. + * @function verify + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ValueInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.type != null && message.hasOwnProperty("type")) { + var error = $root.onnx.TypeProto.verify(message.type); + if (error) + return "type." + error; + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ValueInfoProto} ValueInfoProto + */ + ValueInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ValueInfoProto) + return object; + var message = new $root.onnx.ValueInfoProto(); + if (object.name != null) + message.name = String(object.name); + if (object.type != null) { + if (typeof object.type !== "object") + throw TypeError(".onnx.ValueInfoProto.type: object expected"); + message.type = $root.onnx.TypeProto.fromObject(object.type); + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.ValueInfoProto} message ValueInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ValueInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.name = ""; + object.type = null; + object.docString = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.type != null && message.hasOwnProperty("type")) + object.type = $root.onnx.TypeProto.toObject(message.type, options); + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + return object; + }; + + /** + * Converts this ValueInfoProto to JSON. + * @function toJSON + * @memberof onnx.ValueInfoProto + * @instance + * @returns {Object.} JSON object + */ + ValueInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ValueInfoProto + * @function getTypeUrl + * @memberof onnx.ValueInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ValueInfoProto"; + }; + + return ValueInfoProto; + })(); + + onnx.NodeProto = (function() { + + /** + * Properties of a NodeProto. + * @memberof onnx + * @interface INodeProto + * @property {Array.|null} [input] NodeProto input + * @property {Array.|null} [output] NodeProto output + * @property {string|null} [name] NodeProto name + * @property {string|null} [opType] NodeProto opType + * @property {string|null} [domain] NodeProto domain + * @property {Array.|null} [attribute] NodeProto attribute + * @property {string|null} [docString] NodeProto docString + */ + + /** + * Constructs a new NodeProto. + * @memberof onnx + * @classdesc Represents a NodeProto. + * @implements INodeProto + * @constructor + * @param {onnx.INodeProto=} [properties] Properties to set + */ + function NodeProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * NodeProto input. + * @member {Array.} input + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.input = $util.emptyArray; + + /** + * NodeProto output. + * @member {Array.} output + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.output = $util.emptyArray; + + /** + * NodeProto name. + * @member {string} name + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.name = ""; + + /** + * NodeProto opType. + * @member {string} opType + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.opType = ""; + + /** + * NodeProto domain. + * @member {string} domain + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.domain = ""; + + /** + * NodeProto attribute. + * @member {Array.} attribute + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.attribute = $util.emptyArray; + + /** + * NodeProto docString. + * @member {string} docString + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.docString = ""; + + /** + * Creates a new NodeProto instance using the specified properties. + * @function create + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto=} [properties] Properties to set + * @returns {onnx.NodeProto} NodeProto instance + */ + NodeProto.create = function create(properties) { + return new NodeProto(properties); + }; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encode + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.output[i]); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.name); + if (message.opType != null && Object.hasOwnProperty.call(message, "opType")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.opType); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + $root.onnx.AttributeProto.encode(message.attribute[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 7, wireType 2 =*/58).string(message.domain); + return writer; + }; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.NodeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 2: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 3: { + message.name = reader.string(); + break; + } + case 4: { + message.opType = reader.string(); + break; + } + case 7: { + message.domain = reader.string(); + break; + } + case 5: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a NodeProto message. + * @function verify + * @memberof onnx.NodeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + NodeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.opType != null && message.hasOwnProperty("opType")) + if (!$util.isString(message.opType)) + return "opType: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attribute[i]); + if (error) + return "attribute." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.NodeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.NodeProto} NodeProto + */ + NodeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.NodeProto) + return object; + var message = new $root.onnx.NodeProto(); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.NodeProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.NodeProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.name != null) + message.name = String(object.name); + if (object.opType != null) + message.opType = String(object.opType); + if (object.domain != null) + message.domain = String(object.domain); + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.NodeProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) { + if (typeof object.attribute[i] !== "object") + throw TypeError(".onnx.NodeProto.attribute: object expected"); + message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.NodeProto + * @static + * @param {onnx.NodeProto} message NodeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + NodeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + } + if (options.defaults) { + object.name = ""; + object.opType = ""; + object.docString = ""; + object.domain = ""; + } + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.opType != null && message.hasOwnProperty("opType")) + object.opType = message.opType; + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + return object; + }; + + /** + * Converts this NodeProto to JSON. + * @function toJSON + * @memberof onnx.NodeProto + * @instance + * @returns {Object.} JSON object + */ + NodeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for NodeProto + * @function getTypeUrl + * @memberof onnx.NodeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.NodeProto"; + }; + + return NodeProto; + })(); + + onnx.TrainingInfoProto = (function() { + + /** + * Properties of a TrainingInfoProto. + * @memberof onnx + * @interface ITrainingInfoProto + * @property {onnx.IGraphProto|null} [initialization] TrainingInfoProto initialization + * @property {onnx.IGraphProto|null} [algorithm] TrainingInfoProto algorithm + * @property {Array.|null} [initializationBinding] TrainingInfoProto initializationBinding + * @property {Array.|null} [updateBinding] TrainingInfoProto updateBinding + */ + + /** + * Constructs a new TrainingInfoProto. + * @memberof onnx + * @classdesc Represents a TrainingInfoProto. + * @implements ITrainingInfoProto + * @constructor + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + */ + function TrainingInfoProto(properties) { + this.initializationBinding = []; + this.updateBinding = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TrainingInfoProto initialization. + * @member {onnx.IGraphProto|null|undefined} initialization + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initialization = null; + + /** + * TrainingInfoProto algorithm. + * @member {onnx.IGraphProto|null|undefined} algorithm + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.algorithm = null; + + /** + * TrainingInfoProto initializationBinding. + * @member {Array.} initializationBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initializationBinding = $util.emptyArray; + + /** + * TrainingInfoProto updateBinding. + * @member {Array.} updateBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.updateBinding = $util.emptyArray; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @function create + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + * @returns {onnx.TrainingInfoProto} TrainingInfoProto instance + */ + TrainingInfoProto.create = function create(properties) { + return new TrainingInfoProto(properties); + }; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encode + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.initialization != null && Object.hasOwnProperty.call(message, "initialization")) + $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.algorithm != null && Object.hasOwnProperty.call(message, "algorithm")) + $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.initializationBinding != null && message.initializationBinding.length) + for (var i = 0; i < message.initializationBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.initializationBinding[i], writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.updateBinding != null && message.updateBinding.length) + for (var i = 0; i < message.updateBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.updateBinding[i], writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TrainingInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.initializationBinding && message.initializationBinding.length)) + message.initializationBinding = []; + message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 4: { + if (!(message.updateBinding && message.updateBinding.length)) + message.updateBinding = []; + message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TrainingInfoProto message. + * @function verify + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TrainingInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.initialization != null && message.hasOwnProperty("initialization")) { + var error = $root.onnx.GraphProto.verify(message.initialization); + if (error) + return "initialization." + error; + } + if (message.algorithm != null && message.hasOwnProperty("algorithm")) { + var error = $root.onnx.GraphProto.verify(message.algorithm); + if (error) + return "algorithm." + error; + } + if (message.initializationBinding != null && message.hasOwnProperty("initializationBinding")) { + if (!Array.isArray(message.initializationBinding)) + return "initializationBinding: array expected"; + for (var i = 0; i < message.initializationBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); + if (error) + return "initializationBinding." + error; + } + } + if (message.updateBinding != null && message.hasOwnProperty("updateBinding")) { + if (!Array.isArray(message.updateBinding)) + return "updateBinding: array expected"; + for (var i = 0; i < message.updateBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); + if (error) + return "updateBinding." + error; + } + } + return null; + }; + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + */ + TrainingInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TrainingInfoProto) + return object; + var message = new $root.onnx.TrainingInfoProto(); + if (object.initialization != null) { + if (typeof object.initialization !== "object") + throw TypeError(".onnx.TrainingInfoProto.initialization: object expected"); + message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); + } + if (object.algorithm != null) { + if (typeof object.algorithm !== "object") + throw TypeError(".onnx.TrainingInfoProto.algorithm: object expected"); + message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); + } + if (object.initializationBinding) { + if (!Array.isArray(object.initializationBinding)) + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: array expected"); + message.initializationBinding = []; + for (var i = 0; i < object.initializationBinding.length; ++i) { + if (typeof object.initializationBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: object expected"); + message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.initializationBinding[i]); + } + } + if (object.updateBinding) { + if (!Array.isArray(object.updateBinding)) + throw TypeError(".onnx.TrainingInfoProto.updateBinding: array expected"); + message.updateBinding = []; + for (var i = 0; i < object.updateBinding.length; ++i) { + if (typeof object.updateBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.updateBinding: object expected"); + message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.TrainingInfoProto} message TrainingInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TrainingInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.initializationBinding = []; + object.updateBinding = []; + } + if (options.defaults) { + object.initialization = null; + object.algorithm = null; + } + if (message.initialization != null && message.hasOwnProperty("initialization")) + object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); + if (message.algorithm != null && message.hasOwnProperty("algorithm")) + object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); + if (message.initializationBinding && message.initializationBinding.length) { + object.initializationBinding = []; + for (var j = 0; j < message.initializationBinding.length; ++j) + object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.initializationBinding[j], options); + } + if (message.updateBinding && message.updateBinding.length) { + object.updateBinding = []; + for (var j = 0; j < message.updateBinding.length; ++j) + object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); + } + return object; + }; + + /** + * Converts this TrainingInfoProto to JSON. + * @function toJSON + * @memberof onnx.TrainingInfoProto + * @instance + * @returns {Object.} JSON object + */ + TrainingInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TrainingInfoProto + * @function getTypeUrl + * @memberof onnx.TrainingInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TrainingInfoProto"; + }; + + return TrainingInfoProto; + })(); + + onnx.ModelProto = (function() { + + /** + * Properties of a ModelProto. + * @memberof onnx + * @interface IModelProto + * @property {number|Long|null} [irVersion] ModelProto irVersion + * @property {Array.|null} [opsetImport] ModelProto opsetImport + * @property {string|null} [producerName] ModelProto producerName + * @property {string|null} [producerVersion] ModelProto producerVersion + * @property {string|null} [domain] ModelProto domain + * @property {number|Long|null} [modelVersion] ModelProto modelVersion + * @property {string|null} [docString] ModelProto docString + * @property {onnx.IGraphProto|null} [graph] ModelProto graph + * @property {Array.|null} [metadataProps] ModelProto metadataProps + * @property {Array.|null} [trainingInfo] ModelProto trainingInfo + * @property {Array.|null} [functions] ModelProto functions + */ + + /** + * Constructs a new ModelProto. + * @memberof onnx + * @classdesc Represents a ModelProto. + * @implements IModelProto + * @constructor + * @param {onnx.IModelProto=} [properties] Properties to set + */ + function ModelProto(properties) { + this.opsetImport = []; + this.metadataProps = []; + this.trainingInfo = []; + this.functions = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ModelProto irVersion. + * @member {number|Long} irVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.opsetImport = $util.emptyArray; + + /** + * ModelProto producerName. + * @member {string} producerName + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerName = ""; + + /** + * ModelProto producerVersion. + * @member {string} producerVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerVersion = ""; + + /** + * ModelProto domain. + * @member {string} domain + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.domain = ""; + + /** + * ModelProto modelVersion. + * @member {number|Long} modelVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto docString. + * @member {string} docString + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.docString = ""; + + /** + * ModelProto graph. + * @member {onnx.IGraphProto|null|undefined} graph + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.graph = null; + + /** + * ModelProto metadataProps. + * @member {Array.} metadataProps + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.metadataProps = $util.emptyArray; + + /** + * ModelProto trainingInfo. + * @member {Array.} trainingInfo + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.trainingInfo = $util.emptyArray; + + /** + * ModelProto functions. + * @member {Array.} functions + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.functions = $util.emptyArray; + + /** + * Creates a new ModelProto instance using the specified properties. + * @function create + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto=} [properties] Properties to set + * @returns {onnx.ModelProto} ModelProto instance + */ + ModelProto.create = function create(properties) { + return new ModelProto(properties); + }; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encode + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.irVersion != null && Object.hasOwnProperty.call(message, "irVersion")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.irVersion); + if (message.producerName != null && Object.hasOwnProperty.call(message, "producerName")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.producerName); + if (message.producerVersion != null && Object.hasOwnProperty.call(message, "producerVersion")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.producerVersion); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.domain); + if (message.modelVersion != null && Object.hasOwnProperty.call(message, "modelVersion")) + writer.uint32(/* id 5, wireType 0 =*/40).int64(message.modelVersion); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.graph != null && Object.hasOwnProperty.call(message, "graph")) + $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.metadataProps != null && message.metadataProps.length) + for (var i = 0; i < message.metadataProps.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.metadataProps[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.trainingInfo != null && message.trainingInfo.length) + for (var i = 0; i < message.trainingInfo.length; ++i) + $root.onnx.TrainingInfoProto.encode(message.trainingInfo[i], writer.uint32(/* id 20, wireType 2 =*/162).fork()).ldelim(); + if (message.functions != null && message.functions.length) + for (var i = 0; i < message.functions.length; ++i) + $root.onnx.FunctionProto.encode(message.functions[i], writer.uint32(/* id 25, wireType 2 =*/202).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ModelProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.irVersion = reader.int64(); + break; + } + case 8: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.producerName = reader.string(); + break; + } + case 3: { + message.producerVersion = reader.string(); + break; + } + case 4: { + message.domain = reader.string(); + break; + } + case 5: { + message.modelVersion = reader.int64(); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + case 7: { + message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 14: { + if (!(message.metadataProps && message.metadataProps.length)) + message.metadataProps = []; + message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 20: { + if (!(message.trainingInfo && message.trainingInfo.length)) + message.trainingInfo = []; + message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); + break; + } + case 25: { + if (!(message.functions && message.functions.length)) + message.functions = []; + message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ModelProto message. + * @function verify + * @memberof onnx.ModelProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ModelProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (!$util.isInteger(message.irVersion) && !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high))) + return "irVersion: integer|Long expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.producerName != null && message.hasOwnProperty("producerName")) + if (!$util.isString(message.producerName)) + return "producerName: string expected"; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + if (!$util.isString(message.producerVersion)) + return "producerVersion: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (!$util.isInteger(message.modelVersion) && !(message.modelVersion && $util.isInteger(message.modelVersion.low) && $util.isInteger(message.modelVersion.high))) + return "modelVersion: integer|Long expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.graph != null && message.hasOwnProperty("graph")) { + var error = $root.onnx.GraphProto.verify(message.graph); + if (error) + return "graph." + error; + } + if (message.metadataProps != null && message.hasOwnProperty("metadataProps")) { + if (!Array.isArray(message.metadataProps)) + return "metadataProps: array expected"; + for (var i = 0; i < message.metadataProps.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); + if (error) + return "metadataProps." + error; + } + } + if (message.trainingInfo != null && message.hasOwnProperty("trainingInfo")) { + if (!Array.isArray(message.trainingInfo)) + return "trainingInfo: array expected"; + for (var i = 0; i < message.trainingInfo.length; ++i) { + var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); + if (error) + return "trainingInfo." + error; + } + } + if (message.functions != null && message.hasOwnProperty("functions")) { + if (!Array.isArray(message.functions)) + return "functions: array expected"; + for (var i = 0; i < message.functions.length; ++i) { + var error = $root.onnx.FunctionProto.verify(message.functions[i]); + if (error) + return "functions." + error; + } + } + return null; + }; + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ModelProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ModelProto} ModelProto + */ + ModelProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ModelProto) + return object; + var message = new $root.onnx.ModelProto(); + if (object.irVersion != null) + if ($util.Long) + (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; + else if (typeof object.irVersion === "string") + message.irVersion = parseInt(object.irVersion, 10); + else if (typeof object.irVersion === "number") + message.irVersion = object.irVersion; + else if (typeof object.irVersion === "object") + message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.ModelProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.ModelProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.producerName != null) + message.producerName = String(object.producerName); + if (object.producerVersion != null) + message.producerVersion = String(object.producerVersion); + if (object.domain != null) + message.domain = String(object.domain); + if (object.modelVersion != null) + if ($util.Long) + (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; + else if (typeof object.modelVersion === "string") + message.modelVersion = parseInt(object.modelVersion, 10); + else if (typeof object.modelVersion === "number") + message.modelVersion = object.modelVersion; + else if (typeof object.modelVersion === "object") + message.modelVersion = new $util.LongBits(object.modelVersion.low >>> 0, object.modelVersion.high >>> 0).toNumber(); + if (object.docString != null) + message.docString = String(object.docString); + if (object.graph != null) { + if (typeof object.graph !== "object") + throw TypeError(".onnx.ModelProto.graph: object expected"); + message.graph = $root.onnx.GraphProto.fromObject(object.graph); + } + if (object.metadataProps) { + if (!Array.isArray(object.metadataProps)) + throw TypeError(".onnx.ModelProto.metadataProps: array expected"); + message.metadataProps = []; + for (var i = 0; i < object.metadataProps.length; ++i) { + if (typeof object.metadataProps[i] !== "object") + throw TypeError(".onnx.ModelProto.metadataProps: object expected"); + message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); + } + } + if (object.trainingInfo) { + if (!Array.isArray(object.trainingInfo)) + throw TypeError(".onnx.ModelProto.trainingInfo: array expected"); + message.trainingInfo = []; + for (var i = 0; i < object.trainingInfo.length; ++i) { + if (typeof object.trainingInfo[i] !== "object") + throw TypeError(".onnx.ModelProto.trainingInfo: object expected"); + message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); + } + } + if (object.functions) { + if (!Array.isArray(object.functions)) + throw TypeError(".onnx.ModelProto.functions: array expected"); + message.functions = []; + for (var i = 0; i < object.functions.length; ++i) { + if (typeof object.functions[i] !== "object") + throw TypeError(".onnx.ModelProto.functions: object expected"); + message.functions[i] = $root.onnx.FunctionProto.fromObject(object.functions[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ModelProto + * @static + * @param {onnx.ModelProto} message ModelProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ModelProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.opsetImport = []; + object.metadataProps = []; + object.trainingInfo = []; + object.functions = []; + } + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.irVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.irVersion = options.longs === String ? "0" : 0; + object.producerName = ""; + object.producerVersion = ""; + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.modelVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.modelVersion = options.longs === String ? "0" : 0; + object.docString = ""; + object.graph = null; + } + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (typeof message.irVersion === "number") + object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; + else + object.irVersion = options.longs === String ? $util.Long.prototype.toString.call(message.irVersion) : options.longs === Number ? new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() : message.irVersion; + if (message.producerName != null && message.hasOwnProperty("producerName")) + object.producerName = message.producerName; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + object.producerVersion = message.producerVersion; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (typeof message.modelVersion === "number") + object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; + else + object.modelVersion = options.longs === String ? $util.Long.prototype.toString.call(message.modelVersion) : options.longs === Number ? new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() : message.modelVersion; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.graph != null && message.hasOwnProperty("graph")) + object.graph = $root.onnx.GraphProto.toObject(message.graph, options); + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.metadataProps && message.metadataProps.length) { + object.metadataProps = []; + for (var j = 0; j < message.metadataProps.length; ++j) + object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); + } + if (message.trainingInfo && message.trainingInfo.length) { + object.trainingInfo = []; + for (var j = 0; j < message.trainingInfo.length; ++j) + object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); + } + if (message.functions && message.functions.length) { + object.functions = []; + for (var j = 0; j < message.functions.length; ++j) + object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + } + return object; + }; + + /** + * Converts this ModelProto to JSON. + * @function toJSON + * @memberof onnx.ModelProto + * @instance + * @returns {Object.} JSON object + */ + ModelProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ModelProto + * @function getTypeUrl + * @memberof onnx.ModelProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ModelProto"; + }; + + return ModelProto; + })(); + + onnx.StringStringEntryProto = (function() { + + /** + * Properties of a StringStringEntryProto. + * @memberof onnx + * @interface IStringStringEntryProto + * @property {string|null} [key] StringStringEntryProto key + * @property {string|null} [value] StringStringEntryProto value + */ + + /** + * Constructs a new StringStringEntryProto. + * @memberof onnx + * @classdesc Represents a StringStringEntryProto. + * @implements IStringStringEntryProto + * @constructor + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + */ + function StringStringEntryProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * StringStringEntryProto key. + * @member {string} key + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.key = ""; + + /** + * StringStringEntryProto value. + * @member {string} value + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.value = ""; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. + * @function create + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + * @returns {onnx.StringStringEntryProto} StringStringEntryProto instance + */ + StringStringEntryProto.create = function create(properties) { + return new StringStringEntryProto(properties); + }; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encode + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.key != null && Object.hasOwnProperty.call(message, "key")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.key); + if (message.value != null && Object.hasOwnProperty.call(message, "value")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.value); + return writer; + }; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.StringStringEntryProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.key = reader.string(); + break; + } + case 2: { + message.value = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a StringStringEntryProto message. + * @function verify + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + StringStringEntryProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.key != null && message.hasOwnProperty("key")) + if (!$util.isString(message.key)) + return "key: string expected"; + if (message.value != null && message.hasOwnProperty("value")) + if (!$util.isString(message.value)) + return "value: string expected"; + return null; + }; + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + */ + StringStringEntryProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.StringStringEntryProto) + return object; + var message = new $root.onnx.StringStringEntryProto(); + if (object.key != null) + message.key = String(object.key); + if (object.value != null) + message.value = String(object.value); + return message; + }; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.StringStringEntryProto} message StringStringEntryProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + StringStringEntryProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.key = ""; + object.value = ""; + } + if (message.key != null && message.hasOwnProperty("key")) + object.key = message.key; + if (message.value != null && message.hasOwnProperty("value")) + object.value = message.value; + return object; + }; + + /** + * Converts this StringStringEntryProto to JSON. + * @function toJSON + * @memberof onnx.StringStringEntryProto + * @instance + * @returns {Object.} JSON object + */ + StringStringEntryProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for StringStringEntryProto + * @function getTypeUrl + * @memberof onnx.StringStringEntryProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.StringStringEntryProto"; + }; + + return StringStringEntryProto; + })(); + + onnx.TensorAnnotation = (function() { + + /** + * Properties of a TensorAnnotation. + * @memberof onnx + * @interface ITensorAnnotation + * @property {string|null} [tensorName] TensorAnnotation tensorName + * @property {Array.|null} [quantParameterTensorNames] TensorAnnotation quantParameterTensorNames + */ + + /** + * Constructs a new TensorAnnotation. + * @memberof onnx + * @classdesc Represents a TensorAnnotation. + * @implements ITensorAnnotation + * @constructor + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + */ + function TensorAnnotation(properties) { + this.quantParameterTensorNames = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorAnnotation tensorName. + * @member {string} tensorName + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.tensorName = ""; + + /** + * TensorAnnotation quantParameterTensorNames. + * @member {Array.} quantParameterTensorNames + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.quantParameterTensorNames = $util.emptyArray; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @function create + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + * @returns {onnx.TensorAnnotation} TensorAnnotation instance + */ + TensorAnnotation.create = function create(properties) { + return new TensorAnnotation(properties); + }; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encode + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorName != null && Object.hasOwnProperty.call(message, "tensorName")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.tensorName); + if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.quantParameterTensorNames[i], writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorAnnotation(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorName = reader.string(); + break; + } + case 2: { + if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) + message.quantParameterTensorNames = []; + message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorAnnotation message. + * @function verify + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorAnnotation.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + if (!$util.isString(message.tensorName)) + return "tensorName: string expected"; + if (message.quantParameterTensorNames != null && message.hasOwnProperty("quantParameterTensorNames")) { + if (!Array.isArray(message.quantParameterTensorNames)) + return "quantParameterTensorNames: array expected"; + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); + if (error) + return "quantParameterTensorNames." + error; + } + } + return null; + }; + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorAnnotation} TensorAnnotation + */ + TensorAnnotation.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorAnnotation) + return object; + var message = new $root.onnx.TensorAnnotation(); + if (object.tensorName != null) + message.tensorName = String(object.tensorName); + if (object.quantParameterTensorNames) { + if (!Array.isArray(object.quantParameterTensorNames)) + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: array expected"); + message.quantParameterTensorNames = []; + for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { + if (typeof object.quantParameterTensorNames[i] !== "object") + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: object expected"); + message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject(object.quantParameterTensorNames[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.TensorAnnotation} message TensorAnnotation + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorAnnotation.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.quantParameterTensorNames = []; + if (options.defaults) + object.tensorName = ""; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + object.tensorName = message.tensorName; + if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { + object.quantParameterTensorNames = []; + for (var j = 0; j < message.quantParameterTensorNames.length; ++j) + object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject(message.quantParameterTensorNames[j], options); + } + return object; + }; + + /** + * Converts this TensorAnnotation to JSON. + * @function toJSON + * @memberof onnx.TensorAnnotation + * @instance + * @returns {Object.} JSON object + */ + TensorAnnotation.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorAnnotation + * @function getTypeUrl + * @memberof onnx.TensorAnnotation + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorAnnotation"; + }; + + return TensorAnnotation; + })(); + + onnx.GraphProto = (function() { + + /** + * Properties of a GraphProto. + * @memberof onnx + * @interface IGraphProto + * @property {Array.|null} [node] GraphProto node + * @property {string|null} [name] GraphProto name + * @property {Array.|null} [initializer] GraphProto initializer + * @property {Array.|null} [sparseInitializer] GraphProto sparseInitializer + * @property {string|null} [docString] GraphProto docString + * @property {Array.|null} [input] GraphProto input + * @property {Array.|null} [output] GraphProto output + * @property {Array.|null} [valueInfo] GraphProto valueInfo + * @property {Array.|null} [quantizationAnnotation] GraphProto quantizationAnnotation + */ + + /** + * Constructs a new GraphProto. + * @memberof onnx + * @classdesc Represents a GraphProto. + * @implements IGraphProto + * @constructor + * @param {onnx.IGraphProto=} [properties] Properties to set + */ + function GraphProto(properties) { + this.node = []; + this.initializer = []; + this.sparseInitializer = []; + this.input = []; + this.output = []; + this.valueInfo = []; + this.quantizationAnnotation = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * GraphProto node. + * @member {Array.} node + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.node = $util.emptyArray; + + /** + * GraphProto name. + * @member {string} name + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.name = ""; + + /** + * GraphProto initializer. + * @member {Array.} initializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.initializer = $util.emptyArray; + + /** + * GraphProto sparseInitializer. + * @member {Array.} sparseInitializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.sparseInitializer = $util.emptyArray; + + /** + * GraphProto docString. + * @member {string} docString + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.docString = ""; + + /** + * GraphProto input. + * @member {Array.} input + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.input = $util.emptyArray; + + /** + * GraphProto output. + * @member {Array.} output + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.output = $util.emptyArray; + + /** + * GraphProto valueInfo. + * @member {Array.} valueInfo + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.valueInfo = $util.emptyArray; + + /** + * GraphProto quantizationAnnotation. + * @member {Array.} quantizationAnnotation + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.quantizationAnnotation = $util.emptyArray; + + /** + * Creates a new GraphProto instance using the specified properties. + * @function create + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto=} [properties] Properties to set + * @returns {onnx.GraphProto} GraphProto instance + */ + GraphProto.create = function create(properties) { + return new GraphProto(properties); + }; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encode + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.name); + if (message.initializer != null && message.initializer.length) + for (var i = 0; i < message.initializer.length; ++i) + $root.onnx.TensorProto.encode(message.initializer[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.docString); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + $root.onnx.ValueInfoProto.encode(message.input[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + $root.onnx.ValueInfoProto.encode(message.output[i], writer.uint32(/* id 12, wireType 2 =*/98).fork()).ldelim(); + if (message.valueInfo != null && message.valueInfo.length) + for (var i = 0; i < message.valueInfo.length; ++i) + $root.onnx.ValueInfoProto.encode(message.valueInfo[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) + for (var i = 0; i < message.quantizationAnnotation.length; ++i) + $root.onnx.TensorAnnotation.encode(message.quantizationAnnotation[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.sparseInitializer != null && message.sparseInitializer.length) + for (var i = 0; i < message.sparseInitializer.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseInitializer[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.GraphProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.name = reader.string(); + break; + } + case 5: { + if (!(message.initializer && message.initializer.length)) + message.initializer = []; + message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.sparseInitializer && message.sparseInitializer.length)) + message.sparseInitializer = []; + message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.docString = reader.string(); + break; + } + case 11: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 12: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 13: { + if (!(message.valueInfo && message.valueInfo.length)) + message.valueInfo = []; + message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 14: { + if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) + message.quantizationAnnotation = []; + message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a GraphProto message. + * @function verify + * @memberof onnx.GraphProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + GraphProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.initializer != null && message.hasOwnProperty("initializer")) { + if (!Array.isArray(message.initializer)) + return "initializer: array expected"; + for (var i = 0; i < message.initializer.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.initializer[i]); + if (error) + return "initializer." + error; + } + } + if (message.sparseInitializer != null && message.hasOwnProperty("sparseInitializer")) { + if (!Array.isArray(message.sparseInitializer)) + return "sparseInitializer: array expected"; + for (var i = 0; i < message.sparseInitializer.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); + if (error) + return "sparseInitializer." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.input[i]); + if (error) + return "input." + error; + } + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.output[i]); + if (error) + return "output." + error; + } + } + if (message.valueInfo != null && message.hasOwnProperty("valueInfo")) { + if (!Array.isArray(message.valueInfo)) + return "valueInfo: array expected"; + for (var i = 0; i < message.valueInfo.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); + if (error) + return "valueInfo." + error; + } + } + if (message.quantizationAnnotation != null && message.hasOwnProperty("quantizationAnnotation")) { + if (!Array.isArray(message.quantizationAnnotation)) + return "quantizationAnnotation: array expected"; + for (var i = 0; i < message.quantizationAnnotation.length; ++i) { + var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); + if (error) + return "quantizationAnnotation." + error; + } + } + return null; + }; + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.GraphProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.GraphProto} GraphProto + */ + GraphProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.GraphProto) + return object; + var message = new $root.onnx.GraphProto(); + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.GraphProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.GraphProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.name != null) + message.name = String(object.name); + if (object.initializer) { + if (!Array.isArray(object.initializer)) + throw TypeError(".onnx.GraphProto.initializer: array expected"); + message.initializer = []; + for (var i = 0; i < object.initializer.length; ++i) { + if (typeof object.initializer[i] !== "object") + throw TypeError(".onnx.GraphProto.initializer: object expected"); + message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); + } + } + if (object.sparseInitializer) { + if (!Array.isArray(object.sparseInitializer)) + throw TypeError(".onnx.GraphProto.sparseInitializer: array expected"); + message.sparseInitializer = []; + for (var i = 0; i < object.sparseInitializer.length; ++i) { + if (typeof object.sparseInitializer[i] !== "object") + throw TypeError(".onnx.GraphProto.sparseInitializer: object expected"); + message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.GraphProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) { + if (typeof object.input[i] !== "object") + throw TypeError(".onnx.GraphProto.input: object expected"); + message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); + } + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.GraphProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) { + if (typeof object.output[i] !== "object") + throw TypeError(".onnx.GraphProto.output: object expected"); + message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); + } + } + if (object.valueInfo) { + if (!Array.isArray(object.valueInfo)) + throw TypeError(".onnx.GraphProto.valueInfo: array expected"); + message.valueInfo = []; + for (var i = 0; i < object.valueInfo.length; ++i) { + if (typeof object.valueInfo[i] !== "object") + throw TypeError(".onnx.GraphProto.valueInfo: object expected"); + message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); + } + } + if (object.quantizationAnnotation) { + if (!Array.isArray(object.quantizationAnnotation)) + throw TypeError(".onnx.GraphProto.quantizationAnnotation: array expected"); + message.quantizationAnnotation = []; + for (var i = 0; i < object.quantizationAnnotation.length; ++i) { + if (typeof object.quantizationAnnotation[i] !== "object") + throw TypeError(".onnx.GraphProto.quantizationAnnotation: object expected"); + message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a GraphProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.GraphProto + * @static + * @param {onnx.GraphProto} message GraphProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + GraphProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.node = []; + object.initializer = []; + object.input = []; + object.output = []; + object.valueInfo = []; + object.quantizationAnnotation = []; + object.sparseInitializer = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.initializer && message.initializer.length) { + object.initializer = []; + for (var j = 0; j < message.initializer.length; ++j) + object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); + } + if (message.valueInfo && message.valueInfo.length) { + object.valueInfo = []; + for (var j = 0; j < message.valueInfo.length; ++j) + object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); + } + if (message.quantizationAnnotation && message.quantizationAnnotation.length) { + object.quantizationAnnotation = []; + for (var j = 0; j < message.quantizationAnnotation.length; ++j) + object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject(message.quantizationAnnotation[j], options); + } + if (message.sparseInitializer && message.sparseInitializer.length) { + object.sparseInitializer = []; + for (var j = 0; j < message.sparseInitializer.length; ++j) + object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); + } + return object; + }; + + /** + * Converts this GraphProto to JSON. + * @function toJSON + * @memberof onnx.GraphProto + * @instance + * @returns {Object.} JSON object + */ + GraphProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for GraphProto + * @function getTypeUrl + * @memberof onnx.GraphProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.GraphProto"; + }; + + return GraphProto; + })(); + + onnx.TensorProto = (function() { + + /** + * Properties of a TensorProto. + * @memberof onnx + * @interface ITensorProto + * @property {Array.|null} [dims] TensorProto dims + * @property {number|null} [dataType] TensorProto dataType + * @property {onnx.TensorProto.ISegment|null} [segment] TensorProto segment + * @property {Array.|null} [floatData] TensorProto floatData + * @property {Array.|null} [int32Data] TensorProto int32Data + * @property {Array.|null} [stringData] TensorProto stringData + * @property {Array.|null} [int64Data] TensorProto int64Data + * @property {string|null} [name] TensorProto name + * @property {string|null} [docString] TensorProto docString + * @property {Uint8Array|null} [rawData] TensorProto rawData + * @property {Array.|null} [externalData] TensorProto externalData + * @property {onnx.TensorProto.DataLocation|null} [dataLocation] TensorProto dataLocation + * @property {Array.|null} [doubleData] TensorProto doubleData + * @property {Array.|null} [uint64Data] TensorProto uint64Data + */ + + /** + * Constructs a new TensorProto. + * @memberof onnx + * @classdesc Represents a TensorProto. + * @implements ITensorProto + * @constructor + * @param {onnx.ITensorProto=} [properties] Properties to set + */ + function TensorProto(properties) { + this.dims = []; + this.floatData = []; + this.int32Data = []; + this.stringData = []; + this.int64Data = []; + this.externalData = []; + this.doubleData = []; + this.uint64Data = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorProto dims. + * @member {Array.} dims + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dims = $util.emptyArray; + + /** + * TensorProto dataType. + * @member {number} dataType + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataType = 0; + + /** + * TensorProto segment. + * @member {onnx.TensorProto.ISegment|null|undefined} segment + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.segment = null; + + /** + * TensorProto floatData. + * @member {Array.} floatData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.floatData = $util.emptyArray; + + /** + * TensorProto int32Data. + * @member {Array.} int32Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int32Data = $util.emptyArray; + + /** + * TensorProto stringData. + * @member {Array.} stringData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.stringData = $util.emptyArray; + + /** + * TensorProto int64Data. + * @member {Array.} int64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int64Data = $util.emptyArray; + + /** + * TensorProto name. + * @member {string} name + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.name = ""; + + /** + * TensorProto docString. + * @member {string} docString + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.docString = ""; + + /** + * TensorProto rawData. + * @member {Uint8Array} rawData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.rawData = $util.newBuffer([]); + + /** + * TensorProto externalData. + * @member {Array.} externalData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.externalData = $util.emptyArray; + + /** + * TensorProto dataLocation. + * @member {onnx.TensorProto.DataLocation} dataLocation + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataLocation = 0; + + /** + * TensorProto doubleData. + * @member {Array.} doubleData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.doubleData = $util.emptyArray; + + /** + * TensorProto uint64Data. + * @member {Array.} uint64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.uint64Data = $util.emptyArray; + + /** + * Creates a new TensorProto instance using the specified properties. + * @function create + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto=} [properties] Properties to set + * @returns {onnx.TensorProto} TensorProto instance + */ + TensorProto.create = function create(properties) { + return new TensorProto(properties); + }; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 1, wireType 2 =*/10).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + if (message.dataType != null && Object.hasOwnProperty.call(message, "dataType")) + writer.uint32(/* id 2, wireType 0 =*/16).int32(message.dataType); + if (message.segment != null && Object.hasOwnProperty.call(message, "segment")) + $root.onnx.TensorProto.Segment.encode(message.segment, writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.floatData != null && message.floatData.length) { + writer.uint32(/* id 4, wireType 2 =*/34).fork(); + for (var i = 0; i < message.floatData.length; ++i) + writer.float(message.floatData[i]); + writer.ldelim(); + } + if (message.int32Data != null && message.int32Data.length) { + writer.uint32(/* id 5, wireType 2 =*/42).fork(); + for (var i = 0; i < message.int32Data.length; ++i) + writer.int32(message.int32Data[i]); + writer.ldelim(); + } + if (message.stringData != null && message.stringData.length) + for (var i = 0; i < message.stringData.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).bytes(message.stringData[i]); + if (message.int64Data != null && message.int64Data.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.int64Data.length; ++i) + writer.int64(message.int64Data[i]); + writer.ldelim(); + } + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.name); + if (message.rawData != null && Object.hasOwnProperty.call(message, "rawData")) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.rawData); + if (message.doubleData != null && message.doubleData.length) { + writer.uint32(/* id 10, wireType 2 =*/82).fork(); + for (var i = 0; i < message.doubleData.length; ++i) + writer.double(message.doubleData[i]); + writer.ldelim(); + } + if (message.uint64Data != null && message.uint64Data.length) { + writer.uint32(/* id 11, wireType 2 =*/90).fork(); + for (var i = 0; i < message.uint64Data.length; ++i) + writer.uint64(message.uint64Data[i]); + writer.ldelim(); + } + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 12, wireType 2 =*/98).string(message.docString); + if (message.externalData != null && message.externalData.length) + for (var i = 0; i < message.externalData.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.externalData[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.dataLocation != null && Object.hasOwnProperty.call(message, "dataLocation")) + writer.uint32(/* id 14, wireType 0 =*/112).int32(message.dataLocation); + return writer; + }; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + case 2: { + message.dataType = reader.int32(); + break; + } + case 3: { + message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); + break; + } + case 4: { + if (!(message.floatData && message.floatData.length)) + message.floatData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floatData.push(reader.float()); + } else + message.floatData.push(reader.float()); + break; + } + case 5: { + if (!(message.int32Data && message.int32Data.length)) + message.int32Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int32Data.push(reader.int32()); + } else + message.int32Data.push(reader.int32()); + break; + } + case 6: { + if (!(message.stringData && message.stringData.length)) + message.stringData = []; + message.stringData.push(reader.bytes()); + break; + } + case 7: { + if (!(message.int64Data && message.int64Data.length)) + message.int64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int64Data.push(reader.int64()); + } else + message.int64Data.push(reader.int64()); + break; + } + case 8: { + message.name = reader.string(); + break; + } + case 12: { + message.docString = reader.string(); + break; + } + case 9: { + message.rawData = reader.bytes(); + break; + } + case 13: { + if (!(message.externalData && message.externalData.length)) + message.externalData = []; + message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 14: { + message.dataLocation = reader.int32(); + break; + } + case 10: { + if (!(message.doubleData && message.doubleData.length)) + message.doubleData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.doubleData.push(reader.double()); + } else + message.doubleData.push(reader.double()); + break; + } + case 11: { + if (!(message.uint64Data && message.uint64Data.length)) + message.uint64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.uint64Data.push(reader.uint64()); + } else + message.uint64Data.push(reader.uint64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorProto message. + * @function verify + * @memberof onnx.TensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + if (!$util.isInteger(message.dataType)) + return "dataType: integer expected"; + if (message.segment != null && message.hasOwnProperty("segment")) { + var error = $root.onnx.TensorProto.Segment.verify(message.segment); + if (error) + return "segment." + error; + } + if (message.floatData != null && message.hasOwnProperty("floatData")) { + if (!Array.isArray(message.floatData)) + return "floatData: array expected"; + for (var i = 0; i < message.floatData.length; ++i) + if (typeof message.floatData[i] !== "number") + return "floatData: number[] expected"; + } + if (message.int32Data != null && message.hasOwnProperty("int32Data")) { + if (!Array.isArray(message.int32Data)) + return "int32Data: array expected"; + for (var i = 0; i < message.int32Data.length; ++i) + if (!$util.isInteger(message.int32Data[i])) + return "int32Data: integer[] expected"; + } + if (message.stringData != null && message.hasOwnProperty("stringData")) { + if (!Array.isArray(message.stringData)) + return "stringData: array expected"; + for (var i = 0; i < message.stringData.length; ++i) + if (!(message.stringData[i] && typeof message.stringData[i].length === "number" || $util.isString(message.stringData[i]))) + return "stringData: buffer[] expected"; + } + if (message.int64Data != null && message.hasOwnProperty("int64Data")) { + if (!Array.isArray(message.int64Data)) + return "int64Data: array expected"; + for (var i = 0; i < message.int64Data.length; ++i) + if (!$util.isInteger(message.int64Data[i]) && !(message.int64Data[i] && $util.isInteger(message.int64Data[i].low) && $util.isInteger(message.int64Data[i].high))) + return "int64Data: integer|Long[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.rawData != null && message.hasOwnProperty("rawData")) + if (!(message.rawData && typeof message.rawData.length === "number" || $util.isString(message.rawData))) + return "rawData: buffer expected"; + if (message.externalData != null && message.hasOwnProperty("externalData")) { + if (!Array.isArray(message.externalData)) + return "externalData: array expected"; + for (var i = 0; i < message.externalData.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); + if (error) + return "externalData." + error; + } + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + switch (message.dataLocation) { + default: + return "dataLocation: enum value expected"; + case 0: + case 1: + break; + } + if (message.doubleData != null && message.hasOwnProperty("doubleData")) { + if (!Array.isArray(message.doubleData)) + return "doubleData: array expected"; + for (var i = 0; i < message.doubleData.length; ++i) + if (typeof message.doubleData[i] !== "number") + return "doubleData: number[] expected"; + } + if (message.uint64Data != null && message.hasOwnProperty("uint64Data")) { + if (!Array.isArray(message.uint64Data)) + return "uint64Data: array expected"; + for (var i = 0; i < message.uint64Data.length; ++i) + if (!$util.isInteger(message.uint64Data[i]) && !(message.uint64Data[i] && $util.isInteger(message.uint64Data[i].low) && $util.isInteger(message.uint64Data[i].high))) + return "uint64Data: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto} TensorProto + */ + TensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto) + return object; + var message = new $root.onnx.TensorProto(); + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.TensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + if (object.dataType != null) + message.dataType = object.dataType | 0; + if (object.segment != null) { + if (typeof object.segment !== "object") + throw TypeError(".onnx.TensorProto.segment: object expected"); + message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); + } + if (object.floatData) { + if (!Array.isArray(object.floatData)) + throw TypeError(".onnx.TensorProto.floatData: array expected"); + message.floatData = []; + for (var i = 0; i < object.floatData.length; ++i) + message.floatData[i] = Number(object.floatData[i]); + } + if (object.int32Data) { + if (!Array.isArray(object.int32Data)) + throw TypeError(".onnx.TensorProto.int32Data: array expected"); + message.int32Data = []; + for (var i = 0; i < object.int32Data.length; ++i) + message.int32Data[i] = object.int32Data[i] | 0; + } + if (object.stringData) { + if (!Array.isArray(object.stringData)) + throw TypeError(".onnx.TensorProto.stringData: array expected"); + message.stringData = []; + for (var i = 0; i < object.stringData.length; ++i) + if (typeof object.stringData[i] === "string") + $util.base64.decode(object.stringData[i], message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i])), 0); + else if (object.stringData[i].length >= 0) + message.stringData[i] = object.stringData[i]; + } + if (object.int64Data) { + if (!Array.isArray(object.int64Data)) + throw TypeError(".onnx.TensorProto.int64Data: array expected"); + message.int64Data = []; + for (var i = 0; i < object.int64Data.length; ++i) + if ($util.Long) + (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; + else if (typeof object.int64Data[i] === "string") + message.int64Data[i] = parseInt(object.int64Data[i], 10); + else if (typeof object.int64Data[i] === "number") + message.int64Data[i] = object.int64Data[i]; + else if (typeof object.int64Data[i] === "object") + message.int64Data[i] = new $util.LongBits(object.int64Data[i].low >>> 0, object.int64Data[i].high >>> 0).toNumber(); + } + if (object.name != null) + message.name = String(object.name); + if (object.docString != null) + message.docString = String(object.docString); + if (object.rawData != null) + if (typeof object.rawData === "string") + $util.base64.decode(object.rawData, message.rawData = $util.newBuffer($util.base64.length(object.rawData)), 0); + else if (object.rawData.length >= 0) + message.rawData = object.rawData; + if (object.externalData) { + if (!Array.isArray(object.externalData)) + throw TypeError(".onnx.TensorProto.externalData: array expected"); + message.externalData = []; + for (var i = 0; i < object.externalData.length; ++i) { + if (typeof object.externalData[i] !== "object") + throw TypeError(".onnx.TensorProto.externalData: object expected"); + message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); + } + } + switch (object.dataLocation) { + default: + if (typeof object.dataLocation === "number") { + message.dataLocation = object.dataLocation; + break; + } + break; + case "DEFAULT": + case 0: + message.dataLocation = 0; + break; + case "EXTERNAL": + case 1: + message.dataLocation = 1; + break; + } + if (object.doubleData) { + if (!Array.isArray(object.doubleData)) + throw TypeError(".onnx.TensorProto.doubleData: array expected"); + message.doubleData = []; + for (var i = 0; i < object.doubleData.length; ++i) + message.doubleData[i] = Number(object.doubleData[i]); + } + if (object.uint64Data) { + if (!Array.isArray(object.uint64Data)) + throw TypeError(".onnx.TensorProto.uint64Data: array expected"); + message.uint64Data = []; + for (var i = 0; i < object.uint64Data.length; ++i) + if ($util.Long) + (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; + else if (typeof object.uint64Data[i] === "string") + message.uint64Data[i] = parseInt(object.uint64Data[i], 10); + else if (typeof object.uint64Data[i] === "number") + message.uint64Data[i] = object.uint64Data[i]; + else if (typeof object.uint64Data[i] === "object") + message.uint64Data[i] = new $util.LongBits(object.uint64Data[i].low >>> 0, object.uint64Data[i].high >>> 0).toNumber(true); + } + return message; + }; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto + * @static + * @param {onnx.TensorProto} message TensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.dims = []; + object.floatData = []; + object.int32Data = []; + object.stringData = []; + object.int64Data = []; + object.doubleData = []; + object.uint64Data = []; + object.externalData = []; + } + if (options.defaults) { + object.dataType = 0; + object.segment = null; + object.name = ""; + if (options.bytes === String) + object.rawData = ""; + else { + object.rawData = []; + if (options.bytes !== Array) + object.rawData = $util.newBuffer(object.rawData); + } + object.docString = ""; + object.dataLocation = options.enums === String ? "DEFAULT" : 0; + } + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + object.dataType = message.dataType; + if (message.segment != null && message.hasOwnProperty("segment")) + object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); + if (message.floatData && message.floatData.length) { + object.floatData = []; + for (var j = 0; j < message.floatData.length; ++j) + object.floatData[j] = options.json && !isFinite(message.floatData[j]) ? String(message.floatData[j]) : message.floatData[j]; + } + if (message.int32Data && message.int32Data.length) { + object.int32Data = []; + for (var j = 0; j < message.int32Data.length; ++j) + object.int32Data[j] = message.int32Data[j]; + } + if (message.stringData && message.stringData.length) { + object.stringData = []; + for (var j = 0; j < message.stringData.length; ++j) + object.stringData[j] = options.bytes === String ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.stringData[j]) : message.stringData[j]; + } + if (message.int64Data && message.int64Data.length) { + object.int64Data = []; + for (var j = 0; j < message.int64Data.length; ++j) + if (typeof message.int64Data[j] === "number") + object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; + else + object.int64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.int64Data[j]) : options.longs === Number ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() : message.int64Data[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.rawData != null && message.hasOwnProperty("rawData")) + object.rawData = options.bytes === String ? $util.base64.encode(message.rawData, 0, message.rawData.length) : options.bytes === Array ? Array.prototype.slice.call(message.rawData) : message.rawData; + if (message.doubleData && message.doubleData.length) { + object.doubleData = []; + for (var j = 0; j < message.doubleData.length; ++j) + object.doubleData[j] = options.json && !isFinite(message.doubleData[j]) ? String(message.doubleData[j]) : message.doubleData[j]; + } + if (message.uint64Data && message.uint64Data.length) { + object.uint64Data = []; + for (var j = 0; j < message.uint64Data.length; ++j) + if (typeof message.uint64Data[j] === "number") + object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; + else + object.uint64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.uint64Data[j]) : options.longs === Number ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) : message.uint64Data[j]; + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.externalData && message.externalData.length) { + object.externalData = []; + for (var j = 0; j < message.externalData.length; ++j) + object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + object.dataLocation = options.enums === String ? $root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined ? message.dataLocation : $root.onnx.TensorProto.DataLocation[message.dataLocation] : message.dataLocation; + return object; + }; + + /** + * Converts this TensorProto to JSON. + * @function toJSON + * @memberof onnx.TensorProto + * @instance + * @returns {Object.} JSON object + */ + TensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorProto + * @function getTypeUrl + * @memberof onnx.TensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto"; + }; + + /** + * DataType enum. + * @name onnx.TensorProto.DataType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} UINT8=2 UINT8 value + * @property {number} INT8=3 INT8 value + * @property {number} UINT16=4 UINT16 value + * @property {number} INT16=5 INT16 value + * @property {number} INT32=6 INT32 value + * @property {number} INT64=7 INT64 value + * @property {number} STRING=8 STRING value + * @property {number} BOOL=9 BOOL value + * @property {number} FLOAT16=10 FLOAT16 value + * @property {number} DOUBLE=11 DOUBLE value + * @property {number} UINT32=12 UINT32 value + * @property {number} UINT64=13 UINT64 value + * @property {number} COMPLEX64=14 COMPLEX64 value + * @property {number} COMPLEX128=15 COMPLEX128 value + * @property {number} BFLOAT16=16 BFLOAT16 value + * @property {number} FLOAT8E4M3FN=17 FLOAT8E4M3FN value + * @property {number} FLOAT8E4M3FNUZ=18 FLOAT8E4M3FNUZ value + * @property {number} FLOAT8E5M2=19 FLOAT8E5M2 value + * @property {number} FLOAT8E5M2FNUZ=20 FLOAT8E5M2FNUZ value + */ + TensorProto.DataType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "UINT8"] = 2; + values[valuesById[3] = "INT8"] = 3; + values[valuesById[4] = "UINT16"] = 4; + values[valuesById[5] = "INT16"] = 5; + values[valuesById[6] = "INT32"] = 6; + values[valuesById[7] = "INT64"] = 7; + values[valuesById[8] = "STRING"] = 8; + values[valuesById[9] = "BOOL"] = 9; + values[valuesById[10] = "FLOAT16"] = 10; + values[valuesById[11] = "DOUBLE"] = 11; + values[valuesById[12] = "UINT32"] = 12; + values[valuesById[13] = "UINT64"] = 13; + values[valuesById[14] = "COMPLEX64"] = 14; + values[valuesById[15] = "COMPLEX128"] = 15; + values[valuesById[16] = "BFLOAT16"] = 16; + values[valuesById[17] = "FLOAT8E4M3FN"] = 17; + values[valuesById[18] = "FLOAT8E4M3FNUZ"] = 18; + values[valuesById[19] = "FLOAT8E5M2"] = 19; + values[valuesById[20] = "FLOAT8E5M2FNUZ"] = 20; + return values; + })(); + + TensorProto.Segment = (function() { + + /** + * Properties of a Segment. + * @memberof onnx.TensorProto + * @interface ISegment + * @property {number|Long|null} [begin] Segment begin + * @property {number|Long|null} [end] Segment end + */ + + /** + * Constructs a new Segment. + * @memberof onnx.TensorProto + * @classdesc Represents a Segment. + * @implements ISegment + * @constructor + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + */ + function Segment(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Segment begin. + * @member {number|Long} begin + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Segment end. + * @member {number|Long} end + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.end = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new Segment instance using the specified properties. + * @function create + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + * @returns {onnx.TensorProto.Segment} Segment instance + */ + Segment.create = function create(properties) { + return new Segment(properties); + }; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.begin != null && Object.hasOwnProperty.call(message, "begin")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.begin); + if (message.end != null && Object.hasOwnProperty.call(message, "end")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.end); + return writer; + }; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto.Segment(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.begin = reader.int64(); + break; + } + case 2: { + message.end = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Segment message. + * @function verify + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Segment.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.begin != null && message.hasOwnProperty("begin")) + if (!$util.isInteger(message.begin) && !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high))) + return "begin: integer|Long expected"; + if (message.end != null && message.hasOwnProperty("end")) + if (!$util.isInteger(message.end) && !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high))) + return "end: integer|Long expected"; + return null; + }; + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto.Segment} Segment + */ + Segment.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto.Segment) + return object; + var message = new $root.onnx.TensorProto.Segment(); + if (object.begin != null) + if ($util.Long) + (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; + else if (typeof object.begin === "string") + message.begin = parseInt(object.begin, 10); + else if (typeof object.begin === "number") + message.begin = object.begin; + else if (typeof object.begin === "object") + message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); + if (object.end != null) + if ($util.Long) + (message.end = $util.Long.fromValue(object.end)).unsigned = false; + else if (typeof object.end === "string") + message.end = parseInt(object.end, 10); + else if (typeof object.end === "number") + message.end = object.end; + else if (typeof object.end === "object") + message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.Segment} message Segment + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Segment.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.begin = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.begin = options.longs === String ? "0" : 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.end = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.end = options.longs === String ? "0" : 0; + } + if (message.begin != null && message.hasOwnProperty("begin")) + if (typeof message.begin === "number") + object.begin = options.longs === String ? String(message.begin) : message.begin; + else + object.begin = options.longs === String ? $util.Long.prototype.toString.call(message.begin) : options.longs === Number ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() : message.begin; + if (message.end != null && message.hasOwnProperty("end")) + if (typeof message.end === "number") + object.end = options.longs === String ? String(message.end) : message.end; + else + object.end = options.longs === String ? $util.Long.prototype.toString.call(message.end) : options.longs === Number ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() : message.end; + return object; + }; + + /** + * Converts this Segment to JSON. + * @function toJSON + * @memberof onnx.TensorProto.Segment + * @instance + * @returns {Object.} JSON object + */ + Segment.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Segment + * @function getTypeUrl + * @memberof onnx.TensorProto.Segment + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto.Segment"; + }; + + return Segment; + })(); + + /** + * DataLocation enum. + * @name onnx.TensorProto.DataLocation + * @enum {number} + * @property {number} DEFAULT=0 DEFAULT value + * @property {number} EXTERNAL=1 EXTERNAL value + */ + TensorProto.DataLocation = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "DEFAULT"] = 0; + values[valuesById[1] = "EXTERNAL"] = 1; + return values; + })(); + + return TensorProto; + })(); + + onnx.SparseTensorProto = (function() { + + /** + * Properties of a SparseTensorProto. + * @memberof onnx + * @interface ISparseTensorProto + * @property {onnx.ITensorProto|null} [values] SparseTensorProto values + * @property {onnx.ITensorProto|null} [indices] SparseTensorProto indices + * @property {Array.|null} [dims] SparseTensorProto dims + */ + + /** + * Constructs a new SparseTensorProto. + * @memberof onnx + * @classdesc Represents a SparseTensorProto. + * @implements ISparseTensorProto + * @constructor + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + */ + function SparseTensorProto(properties) { + this.dims = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensorProto values. + * @member {onnx.ITensorProto|null|undefined} values + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.values = null; + + /** + * SparseTensorProto indices. + * @member {onnx.ITensorProto|null|undefined} indices + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.indices = null; + + /** + * SparseTensorProto dims. + * @member {Array.} dims + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.dims = $util.emptyArray; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @function create + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + * @returns {onnx.SparseTensorProto} SparseTensorProto instance + */ + SparseTensorProto.create = function create(properties) { + return new SparseTensorProto(properties); + }; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.values != null && Object.hasOwnProperty.call(message, "values")) + $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.indices != null && Object.hasOwnProperty.call(message, "indices")) + $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 3, wireType 2 =*/26).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + return writer; + }; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.SparseTensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensorProto message. + * @function verify + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.values != null && message.hasOwnProperty("values")) { + var error = $root.onnx.TensorProto.verify(message.values); + if (error) + return "values." + error; + } + if (message.indices != null && message.hasOwnProperty("indices")) { + var error = $root.onnx.TensorProto.verify(message.indices); + if (error) + return "indices." + error; + } + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.SparseTensorProto} SparseTensorProto + */ + SparseTensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.SparseTensorProto) + return object; + var message = new $root.onnx.SparseTensorProto(); + if (object.values != null) { + if (typeof object.values !== "object") + throw TypeError(".onnx.SparseTensorProto.values: object expected"); + message.values = $root.onnx.TensorProto.fromObject(object.values); + } + if (object.indices != null) { + if (typeof object.indices !== "object") + throw TypeError(".onnx.SparseTensorProto.indices: object expected"); + message.indices = $root.onnx.TensorProto.fromObject(object.indices); + } + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.SparseTensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.SparseTensorProto} message SparseTensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dims = []; + if (options.defaults) { + object.values = null; + object.indices = null; + } + if (message.values != null && message.hasOwnProperty("values")) + object.values = $root.onnx.TensorProto.toObject(message.values, options); + if (message.indices != null && message.hasOwnProperty("indices")) + object.indices = $root.onnx.TensorProto.toObject(message.indices, options); + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + return object; + }; + + /** + * Converts this SparseTensorProto to JSON. + * @function toJSON + * @memberof onnx.SparseTensorProto + * @instance + * @returns {Object.} JSON object + */ + SparseTensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensorProto + * @function getTypeUrl + * @memberof onnx.SparseTensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.SparseTensorProto"; + }; + + return SparseTensorProto; + })(); + + onnx.TensorShapeProto = (function() { + + /** + * Properties of a TensorShapeProto. + * @memberof onnx + * @interface ITensorShapeProto + * @property {Array.|null} [dim] TensorShapeProto dim + */ + + /** + * Constructs a new TensorShapeProto. + * @memberof onnx + * @classdesc Represents a TensorShapeProto. + * @implements ITensorShapeProto + * @constructor + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + */ + function TensorShapeProto(properties) { + this.dim = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorShapeProto dim. + * @member {Array.} dim + * @memberof onnx.TensorShapeProto + * @instance + */ + TensorShapeProto.prototype.dim = $util.emptyArray; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + * @returns {onnx.TensorShapeProto} TensorShapeProto instance + */ + TensorShapeProto.create = function create(properties) { + return new TensorShapeProto(properties); + }; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dim != null && message.dim.length) + for (var i = 0; i < message.dim.length; ++i) + $root.onnx.TensorShapeProto.Dimension.encode(message.dim[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dim && message.dim.length)) + message.dim = []; + message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorShapeProto message. + * @function verify + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorShapeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dim != null && message.hasOwnProperty("dim")) { + if (!Array.isArray(message.dim)) + return "dim: array expected"; + for (var i = 0; i < message.dim.length; ++i) { + var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); + if (error) + return "dim." + error; + } + } + return null; + }; + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto} TensorShapeProto + */ + TensorShapeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto) + return object; + var message = new $root.onnx.TensorShapeProto(); + if (object.dim) { + if (!Array.isArray(object.dim)) + throw TypeError(".onnx.TensorShapeProto.dim: array expected"); + message.dim = []; + for (var i = 0; i < object.dim.length; ++i) { + if (typeof object.dim[i] !== "object") + throw TypeError(".onnx.TensorShapeProto.dim: object expected"); + message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.TensorShapeProto} message TensorShapeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorShapeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dim = []; + if (message.dim && message.dim.length) { + object.dim = []; + for (var j = 0; j < message.dim.length; ++j) + object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); + } + return object; + }; + + /** + * Converts this TensorShapeProto to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto + * @instance + * @returns {Object.} JSON object + */ + TensorShapeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorShapeProto + * @function getTypeUrl + * @memberof onnx.TensorShapeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto"; + }; + + TensorShapeProto.Dimension = (function() { + + /** + * Properties of a Dimension. + * @memberof onnx.TensorShapeProto + * @interface IDimension + * @property {number|Long|null} [dimValue] Dimension dimValue + * @property {string|null} [dimParam] Dimension dimParam + * @property {string|null} [denotation] Dimension denotation + */ + + /** + * Constructs a new Dimension. + * @memberof onnx.TensorShapeProto + * @classdesc Represents a Dimension. + * @implements IDimension + * @constructor + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + */ + function Dimension(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Dimension dimValue. + * @member {number|Long|null|undefined} dimValue + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimValue = null; + + /** + * Dimension dimParam. + * @member {string|null|undefined} dimParam + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimParam = null; + + /** + * Dimension denotation. + * @member {string} denotation + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * Dimension value. + * @member {"dimValue"|"dimParam"|undefined} value + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Object.defineProperty(Dimension.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["dimValue", "dimParam"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new Dimension instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + * @returns {onnx.TensorShapeProto.Dimension} Dimension instance + */ + Dimension.create = function create(properties) { + return new Dimension(properties); + }; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dimValue != null && Object.hasOwnProperty.call(message, "dimValue")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.dimValue); + if (message.dimParam != null && Object.hasOwnProperty.call(message, "dimParam")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.dimParam); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.denotation); + return writer; + }; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto.Dimension(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.dimValue = reader.int64(); + break; + } + case 2: { + message.dimParam = reader.string(); + break; + } + case 3: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Dimension message. + * @function verify + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Dimension.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + properties.value = 1; + if (!$util.isInteger(message.dimValue) && !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high))) + return "dimValue: integer|Long expected"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + if (!$util.isString(message.dimParam)) + return "dimParam: string expected"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto.Dimension} Dimension + */ + Dimension.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto.Dimension) + return object; + var message = new $root.onnx.TensorShapeProto.Dimension(); + if (object.dimValue != null) + if ($util.Long) + (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; + else if (typeof object.dimValue === "string") + message.dimValue = parseInt(object.dimValue, 10); + else if (typeof object.dimValue === "number") + message.dimValue = object.dimValue; + else if (typeof object.dimValue === "object") + message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); + if (object.dimParam != null) + message.dimParam = String(object.dimParam); + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.Dimension} message Dimension + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Dimension.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + if (typeof message.dimValue === "number") + object.dimValue = options.longs === String ? String(message.dimValue) : message.dimValue; + else + object.dimValue = options.longs === String ? $util.Long.prototype.toString.call(message.dimValue) : options.longs === Number ? new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() : message.dimValue; + if (options.oneofs) + object.value = "dimValue"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + object.dimParam = message.dimParam; + if (options.oneofs) + object.value = "dimParam"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + return object; + }; + + /** + * Converts this Dimension to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto.Dimension + * @instance + * @returns {Object.} JSON object + */ + Dimension.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Dimension + * @function getTypeUrl + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto.Dimension"; + }; + + return Dimension; + })(); + + return TensorShapeProto; + })(); + + onnx.TypeProto = (function() { + + /** + * Properties of a TypeProto. + * @memberof onnx + * @interface ITypeProto + * @property {onnx.TypeProto.ITensor|null} [tensorType] TypeProto tensorType + * @property {onnx.TypeProto.ISequence|null} [sequenceType] TypeProto sequenceType + * @property {onnx.TypeProto.IMap|null} [mapType] TypeProto mapType + * @property {onnx.TypeProto.IOptional|null} [optionalType] TypeProto optionalType + * @property {onnx.TypeProto.ISparseTensor|null} [sparseTensorType] TypeProto sparseTensorType + * @property {string|null} [denotation] TypeProto denotation + */ + + /** + * Constructs a new TypeProto. + * @memberof onnx + * @classdesc Represents a TypeProto. + * @implements ITypeProto + * @constructor + * @param {onnx.ITypeProto=} [properties] Properties to set + */ + function TypeProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TypeProto tensorType. + * @member {onnx.TypeProto.ITensor|null|undefined} tensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.tensorType = null; + + /** + * TypeProto sequenceType. + * @member {onnx.TypeProto.ISequence|null|undefined} sequenceType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sequenceType = null; + + /** + * TypeProto mapType. + * @member {onnx.TypeProto.IMap|null|undefined} mapType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.mapType = null; + + /** + * TypeProto optionalType. + * @member {onnx.TypeProto.IOptional|null|undefined} optionalType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.optionalType = null; + + /** + * TypeProto sparseTensorType. + * @member {onnx.TypeProto.ISparseTensor|null|undefined} sparseTensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sparseTensorType = null; + + /** + * TypeProto denotation. + * @member {string} denotation + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * TypeProto value. + * @member {"tensorType"|"sequenceType"|"mapType"|"optionalType"|"sparseTensorType"|undefined} value + * @memberof onnx.TypeProto + * @instance + */ + Object.defineProperty(TypeProto.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["tensorType", "sequenceType", "mapType", "optionalType", "sparseTensorType"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new TypeProto instance using the specified properties. + * @function create + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto=} [properties] Properties to set + * @returns {onnx.TypeProto} TypeProto instance + */ + TypeProto.create = function create(properties) { + return new TypeProto(properties); + }; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorType != null && Object.hasOwnProperty.call(message, "tensorType")) + $root.onnx.TypeProto.Tensor.encode(message.tensorType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.sequenceType != null && Object.hasOwnProperty.call(message, "sequenceType")) + $root.onnx.TypeProto.Sequence.encode(message.sequenceType, writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + if (message.mapType != null && Object.hasOwnProperty.call(message, "mapType")) + $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.denotation); + if (message.sparseTensorType != null && Object.hasOwnProperty.call(message, "sparseTensorType")) + $root.onnx.TypeProto.SparseTensor.encode(message.sparseTensorType, writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.optionalType != null && Object.hasOwnProperty.call(message, "optionalType")) + $root.onnx.TypeProto.Optional.encode(message.optionalType, writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); + break; + } + case 4: { + message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); + break; + } + case 5: { + message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); + break; + } + case 9: { + message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); + break; + } + case 8: { + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); + break; + } + case 6: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TypeProto message. + * @function verify + * @memberof onnx.TypeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TypeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + properties.value = 1; + { + var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); + if (error) + return "tensorType." + error; + } + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); + if (error) + return "sequenceType." + error; + } + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Map.verify(message.mapType); + if (error) + return "mapType." + error; + } + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); + if (error) + return "optionalType." + error; + } + } + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType); + if (error) + return "sparseTensorType." + error; + } + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto} TypeProto + */ + TypeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto) + return object; + var message = new $root.onnx.TypeProto(); + if (object.tensorType != null) { + if (typeof object.tensorType !== "object") + throw TypeError(".onnx.TypeProto.tensorType: object expected"); + message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType); + } + if (object.sequenceType != null) { + if (typeof object.sequenceType !== "object") + throw TypeError(".onnx.TypeProto.sequenceType: object expected"); + message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType); + } + if (object.mapType != null) { + if (typeof object.mapType !== "object") + throw TypeError(".onnx.TypeProto.mapType: object expected"); + message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType); + } + if (object.optionalType != null) { + if (typeof object.optionalType !== "object") + throw TypeError(".onnx.TypeProto.optionalType: object expected"); + message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType); + } + if (object.sparseTensorType != null) { + if (typeof object.sparseTensorType !== "object") + throw TypeError(".onnx.TypeProto.sparseTensorType: object expected"); + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType); + } + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto + * @static + * @param {onnx.TypeProto} message TypeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TypeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); + if (options.oneofs) + object.value = "tensorType"; + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); + if (options.oneofs) + object.value = "sequenceType"; + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); + if (options.oneofs) + object.value = "mapType"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); + if (options.oneofs) + object.value = "sparseTensorType"; + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); + if (options.oneofs) + object.value = "optionalType"; + } + return object; + }; + + /** + * Converts this TypeProto to JSON. + * @function toJSON + * @memberof onnx.TypeProto + * @instance + * @returns {Object.} JSON object + */ + TypeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TypeProto + * @function getTypeUrl + * @memberof onnx.TypeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto"; + }; + + TypeProto.Tensor = (function() { + + /** + * Properties of a Tensor. + * @memberof onnx.TypeProto + * @interface ITensor + * @property {number|null} [elemType] Tensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] Tensor shape + */ + + /** + * Constructs a new Tensor. + * @memberof onnx.TypeProto + * @classdesc Represents a Tensor. + * @implements ITensor + * @constructor + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + */ + function Tensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Tensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.elemType = 0; + + /** + * Tensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.shape = null; + + /** + * Creates a new Tensor instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + * @returns {onnx.TypeProto.Tensor} Tensor instance + */ + Tensor.create = function create(properties) { + return new Tensor(properties); + }; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Tensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Tensor message. + * @function verify + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Tensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Tensor} Tensor + */ + Tensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Tensor) + return object; + var message = new $root.onnx.TypeProto.Tensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.Tensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.Tensor} message Tensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Tensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this Tensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Tensor + * @instance + * @returns {Object.} JSON object + */ + Tensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Tensor + * @function getTypeUrl + * @memberof onnx.TypeProto.Tensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Tensor"; + }; + + return Tensor; + })(); + + TypeProto.Sequence = (function() { + + /** + * Properties of a Sequence. + * @memberof onnx.TypeProto + * @interface ISequence + * @property {onnx.ITypeProto|null} [elemType] Sequence elemType + */ + + /** + * Constructs a new Sequence. + * @memberof onnx.TypeProto + * @classdesc Represents a Sequence. + * @implements ISequence + * @constructor + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + */ + function Sequence(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Sequence elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Sequence + * @instance + */ + Sequence.prototype.elemType = null; + + /** + * Creates a new Sequence instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + * @returns {onnx.TypeProto.Sequence} Sequence instance + */ + Sequence.create = function create(properties) { + return new Sequence(properties); + }; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Sequence(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Sequence message. + * @function verify + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Sequence.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Sequence} Sequence + */ + Sequence.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Sequence) + return object; + var message = new $root.onnx.TypeProto.Sequence(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Sequence.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.Sequence} message Sequence + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Sequence.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Sequence to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Sequence + * @instance + * @returns {Object.} JSON object + */ + Sequence.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Sequence + * @function getTypeUrl + * @memberof onnx.TypeProto.Sequence + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Sequence"; + }; + + return Sequence; + })(); + + TypeProto.Map = (function() { + + /** + * Properties of a Map. + * @memberof onnx.TypeProto + * @interface IMap + * @property {number|null} [keyType] Map keyType + * @property {onnx.ITypeProto|null} [valueType] Map valueType + */ + + /** + * Constructs a new Map. + * @memberof onnx.TypeProto + * @classdesc Represents a Map. + * @implements IMap + * @constructor + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + */ + function Map(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Map keyType. + * @member {number} keyType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.keyType = 0; + + /** + * Map valueType. + * @member {onnx.ITypeProto|null|undefined} valueType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.valueType = null; + + /** + * Creates a new Map instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + * @returns {onnx.TypeProto.Map} Map instance + */ + Map.create = function create(properties) { + return new Map(properties); + }; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.keyType != null && Object.hasOwnProperty.call(message, "keyType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.keyType); + if (message.valueType != null && Object.hasOwnProperty.call(message, "valueType")) + $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Map message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Map(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.keyType = reader.int32(); + break; + } + case 2: { + message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Map message. + * @function verify + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Map.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.keyType != null && message.hasOwnProperty("keyType")) + if (!$util.isInteger(message.keyType)) + return "keyType: integer expected"; + if (message.valueType != null && message.hasOwnProperty("valueType")) { + var error = $root.onnx.TypeProto.verify(message.valueType); + if (error) + return "valueType." + error; + } + return null; + }; + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Map} Map + */ + Map.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Map) + return object; + var message = new $root.onnx.TypeProto.Map(); + if (object.keyType != null) + message.keyType = object.keyType | 0; + if (object.valueType != null) { + if (typeof object.valueType !== "object") + throw TypeError(".onnx.TypeProto.Map.valueType: object expected"); + message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); + } + return message; + }; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.Map} message Map + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Map.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.keyType = 0; + object.valueType = null; + } + if (message.keyType != null && message.hasOwnProperty("keyType")) + object.keyType = message.keyType; + if (message.valueType != null && message.hasOwnProperty("valueType")) + object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); + return object; + }; + + /** + * Converts this Map to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Map + * @instance + * @returns {Object.} JSON object + */ + Map.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Map + * @function getTypeUrl + * @memberof onnx.TypeProto.Map + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Map"; + }; + + return Map; + })(); + + TypeProto.Optional = (function() { + + /** + * Properties of an Optional. + * @memberof onnx.TypeProto + * @interface IOptional + * @property {onnx.ITypeProto|null} [elemType] Optional elemType + */ + + /** + * Constructs a new Optional. + * @memberof onnx.TypeProto + * @classdesc Represents an Optional. + * @implements IOptional + * @constructor + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + */ + function Optional(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Optional elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Optional + * @instance + */ + Optional.prototype.elemType = null; + + /** + * Creates a new Optional instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + * @returns {onnx.TypeProto.Optional} Optional instance + */ + Optional.create = function create(properties) { + return new Optional(properties); + }; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Optional(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an Optional message. + * @function verify + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Optional.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Optional} Optional + */ + Optional.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Optional) + return object; + var message = new $root.onnx.TypeProto.Optional(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Optional.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.Optional} message Optional + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Optional.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Optional to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Optional + * @instance + * @returns {Object.} JSON object + */ + Optional.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Optional + * @function getTypeUrl + * @memberof onnx.TypeProto.Optional + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Optional"; + }; + + return Optional; + })(); + + TypeProto.SparseTensor = (function() { + + /** + * Properties of a SparseTensor. + * @memberof onnx.TypeProto + * @interface ISparseTensor + * @property {number|null} [elemType] SparseTensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] SparseTensor shape + */ + + /** + * Constructs a new SparseTensor. + * @memberof onnx.TypeProto + * @classdesc Represents a SparseTensor. + * @implements ISparseTensor + * @constructor + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + */ + function SparseTensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.elemType = 0; + + /** + * SparseTensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.shape = null; + + /** + * Creates a new SparseTensor instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + * @returns {onnx.TypeProto.SparseTensor} SparseTensor instance + */ + SparseTensor.create = function create(properties) { + return new SparseTensor(properties); + }; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.SparseTensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensor message. + * @function verify + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + */ + SparseTensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.SparseTensor) + return object; + var message = new $root.onnx.TypeProto.SparseTensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.SparseTensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.SparseTensor} message SparseTensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this SparseTensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.SparseTensor + * @instance + * @returns {Object.} JSON object + */ + SparseTensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensor + * @function getTypeUrl + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.SparseTensor"; + }; + + return SparseTensor; + })(); + + return TypeProto; + })(); + + onnx.OperatorSetIdProto = (function() { + + /** + * Properties of an OperatorSetIdProto. + * @memberof onnx + * @interface IOperatorSetIdProto + * @property {string|null} [domain] OperatorSetIdProto domain + * @property {number|Long|null} [version] OperatorSetIdProto version + */ + + /** + * Constructs a new OperatorSetIdProto. + * @memberof onnx + * @classdesc Represents an OperatorSetIdProto. + * @implements IOperatorSetIdProto + * @constructor + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + */ + function OperatorSetIdProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * OperatorSetIdProto domain. + * @member {string} domain + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.domain = ""; + + /** + * OperatorSetIdProto version. + * @member {number|Long} version + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @function create + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto instance + */ + OperatorSetIdProto.create = function create(properties) { + return new OperatorSetIdProto(properties); + }; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.domain); + if (message.version != null && Object.hasOwnProperty.call(message, "version")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.version); + return writer; + }; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.OperatorSetIdProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.domain = reader.string(); + break; + } + case 2: { + message.version = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an OperatorSetIdProto message. + * @function verify + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + OperatorSetIdProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.version != null && message.hasOwnProperty("version")) + if (!$util.isInteger(message.version) && !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high))) + return "version: integer|Long expected"; + return null; + }; + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + */ + OperatorSetIdProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.OperatorSetIdProto) + return object; + var message = new $root.onnx.OperatorSetIdProto(); + if (object.domain != null) + message.domain = String(object.domain); + if (object.version != null) + if ($util.Long) + (message.version = $util.Long.fromValue(object.version)).unsigned = false; + else if (typeof object.version === "string") + message.version = parseInt(object.version, 10); + else if (typeof object.version === "number") + message.version = object.version; + else if (typeof object.version === "object") + message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.OperatorSetIdProto} message OperatorSetIdProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + OperatorSetIdProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.version = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.version = options.longs === String ? "0" : 0; + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.version != null && message.hasOwnProperty("version")) + if (typeof message.version === "number") + object.version = options.longs === String ? String(message.version) : message.version; + else + object.version = options.longs === String ? $util.Long.prototype.toString.call(message.version) : options.longs === Number ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() : message.version; + return object; + }; + + /** + * Converts this OperatorSetIdProto to JSON. + * @function toJSON + * @memberof onnx.OperatorSetIdProto + * @instance + * @returns {Object.} JSON object + */ + OperatorSetIdProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for OperatorSetIdProto + * @function getTypeUrl + * @memberof onnx.OperatorSetIdProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.OperatorSetIdProto"; + }; + + return OperatorSetIdProto; + })(); + + /** + * OperatorStatus enum. + * @name onnx.OperatorStatus + * @enum {number} + * @property {number} EXPERIMENTAL=0 EXPERIMENTAL value + * @property {number} STABLE=1 STABLE value + */ + onnx.OperatorStatus = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "EXPERIMENTAL"] = 0; + values[valuesById[1] = "STABLE"] = 1; + return values; + })(); + + onnx.FunctionProto = (function() { + + /** + * Properties of a FunctionProto. + * @memberof onnx + * @interface IFunctionProto + * @property {string|null} [name] FunctionProto name + * @property {Array.|null} [input] FunctionProto input + * @property {Array.|null} [output] FunctionProto output + * @property {Array.|null} [attribute] FunctionProto attribute + * @property {Array.|null} [attributeProto] FunctionProto attributeProto + * @property {Array.|null} [node] FunctionProto node + * @property {string|null} [docString] FunctionProto docString + * @property {Array.|null} [opsetImport] FunctionProto opsetImport + * @property {string|null} [domain] FunctionProto domain + */ + + /** + * Constructs a new FunctionProto. + * @memberof onnx + * @classdesc Represents a FunctionProto. + * @implements IFunctionProto + * @constructor + * @param {onnx.IFunctionProto=} [properties] Properties to set + */ + function FunctionProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + this.attributeProto = []; + this.node = []; + this.opsetImport = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * FunctionProto name. + * @member {string} name + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.name = ""; + + /** + * FunctionProto input. + * @member {Array.} input + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.input = $util.emptyArray; + + /** + * FunctionProto output. + * @member {Array.} output + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.output = $util.emptyArray; + + /** + * FunctionProto attribute. + * @member {Array.} attribute + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attribute = $util.emptyArray; + + /** + * FunctionProto attributeProto. + * @member {Array.} attributeProto + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attributeProto = $util.emptyArray; + + /** + * FunctionProto node. + * @member {Array.} node + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.node = $util.emptyArray; + + /** + * FunctionProto docString. + * @member {string} docString + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.docString = ""; + + /** + * FunctionProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.opsetImport = $util.emptyArray; + + /** + * FunctionProto domain. + * @member {string} domain + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.domain = ""; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @function create + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto=} [properties] Properties to set + * @returns {onnx.FunctionProto} FunctionProto instance + */ + FunctionProto.create = function create(properties) { + return new FunctionProto(properties); + }; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encode + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 5, wireType 2 =*/42).string(message.output[i]); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.attribute[i]); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.docString); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.domain); + if (message.attributeProto != null && message.attributeProto.length) + for (var i = 0; i < message.attributeProto.length; ++i) + $root.onnx.AttributeProto.encode(message.attributeProto[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.FunctionProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 4: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 5: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 6: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push(reader.string()); + break; + } + case 11: { + if (!(message.attributeProto && message.attributeProto.length)) + message.attributeProto = []; + message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 7: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 8: { + message.docString = reader.string(); + break; + } + case 9: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.domain = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a FunctionProto message. + * @function verify + * @memberof onnx.FunctionProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + FunctionProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) + if (!$util.isString(message.attribute[i])) + return "attribute: string[] expected"; + } + if (message.attributeProto != null && message.hasOwnProperty("attributeProto")) { + if (!Array.isArray(message.attributeProto)) + return "attributeProto: array expected"; + for (var i = 0; i < message.attributeProto.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); + if (error) + return "attributeProto." + error; + } + } + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + return null; + }; + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.FunctionProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.FunctionProto} FunctionProto + */ + FunctionProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.FunctionProto) + return object; + var message = new $root.onnx.FunctionProto(); + if (object.name != null) + message.name = String(object.name); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.FunctionProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.FunctionProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.FunctionProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) + message.attribute[i] = String(object.attribute[i]); + } + if (object.attributeProto) { + if (!Array.isArray(object.attributeProto)) + throw TypeError(".onnx.FunctionProto.attributeProto: array expected"); + message.attributeProto = []; + for (var i = 0; i < object.attributeProto.length; ++i) { + if (typeof object.attributeProto[i] !== "object") + throw TypeError(".onnx.FunctionProto.attributeProto: object expected"); + message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); + } + } + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.FunctionProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.FunctionProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.FunctionProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.FunctionProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.domain != null) + message.domain = String(object.domain); + return message; + }; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.FunctionProto + * @static + * @param {onnx.FunctionProto} message FunctionProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + FunctionProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + object.node = []; + object.opsetImport = []; + object.attributeProto = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + object.domain = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = message.attribute[j]; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.attributeProto && message.attributeProto.length) { + object.attributeProto = []; + for (var j = 0; j < message.attributeProto.length; ++j) + object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); + } + return object; + }; + + /** + * Converts this FunctionProto to JSON. + * @function toJSON + * @memberof onnx.FunctionProto + * @instance + * @returns {Object.} JSON object + */ + FunctionProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for FunctionProto + * @function getTypeUrl + * @memberof onnx.FunctionProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.FunctionProto"; + }; + + return FunctionProto; + })(); + + return onnx; +})(); + +module.exports = $root; diff --git a/js/node/test/test-utils.ts b/js/node/test/test-utils.ts index 968e8a1881810..3eef90356a335 100644 --- a/js/node/test/test-utils.ts +++ b/js/node/test/test-utils.ts @@ -4,10 +4,11 @@ import assert from 'assert'; import * as fs from 'fs-extra'; import {jsonc} from 'jsonc'; -import * as onnx_proto from 'onnx-proto'; import {InferenceSession, Tensor} from 'onnxruntime-common'; import * as path from 'path'; +import * as onnx_proto from './ort-schema/protobuf/onnx'; + export const TEST_ROOT = __dirname; export const TEST_DATA_ROOT = path.join(TEST_ROOT, 'testdata'); diff --git a/js/package-lock.json b/js/package-lock.json index c87a58a3196d6..c16a8b59a3a6f 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -3391,9 +3391,9 @@ } }, "node_modules/normalize-package-data/node_modules/semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true, "bin": { "semver": "bin/semver" @@ -7011,9 +7011,9 @@ }, "dependencies": { "semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true } } diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 0b82a9c031baa..b246e19137888 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -20,6 +20,7 @@ Do not modify directly.* | Asinh | ai.onnx(9+) | | | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | +| Attention | com.microsoft(1+) | need implementing mask and past/present | | AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | | BiasAdd | com.microsoft(1+) | | | BiasSplitGelu | com.microsoft(1+) | | @@ -61,6 +62,7 @@ Do not modify directly.* | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | | Mul | ai.onnx(7-12,13,14+) | | +| MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present | | Neg | ai.onnx(6-12,13+) | | | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts b/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts index dd3f1b30dfb46..1f2b27c7bdea8 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts @@ -186,7 +186,7 @@ export class CoordsGlslLib extends GlslLib { /** * 1D packed output coordinates. */ - protected getOutputPacked1DCoords(shape: [number], texShape: [number, number]): GlslLibRoutine { + protected getOutputPacked1DCoords(_shape: [number], texShape: [number, number]): GlslLibRoutine { const packedTexShape = texShape; let source = ''; if (packedTexShape[0] === 1) { @@ -331,7 +331,7 @@ export class CoordsGlslLib extends GlslLib { /** * Unpacked 1D output coordinates. */ - protected getOutputUnpacked1DCoords(shape: [number], texShape: [number, number]): GlslLibRoutine { + protected getOutputUnpacked1DCoords(_shape: [number], texShape: [number, number]): GlslLibRoutine { const source = ` int getOutputCoords() { ivec2 resTexRC = ivec2(TexCoords.xy * @@ -641,7 +641,7 @@ export class CoordsGlslLib extends GlslLib { if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { - unpackedCoordsSnippet = inShape.map((s, i) => `coords.${fields[i + rankDiff]}`).join(', '); + unpackedCoordsSnippet = inShape.map((_s, i) => `coords.${fields[i + rankDiff]}`).join(', '); } let output = 'return outputValue;'; @@ -734,7 +734,7 @@ export class CoordsGlslLib extends GlslLib { if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { - unpackedCoordsSnippet = inputLayout.unpackedShape.map((s, i) => `coords.${fields[i + rankDiff]}`).join(', '); + unpackedCoordsSnippet = inputLayout.unpackedShape.map((_s, i) => `coords.${fields[i + rankDiff]}`).join(', '); } const source = ` float ${funcName}() { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts b/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts index 709f883ae12c9..d0e589a428825 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/concat-packed.ts @@ -12,7 +12,7 @@ import {getChannels, unpackFromChannel} from './packing-utils'; const createPackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => ({ name: 'Concat (packed)', - inputNames: Array.from({length: inputCount}, (v, i) => `X${i}`), + inputNames: Array.from({length: inputCount}, (_v, i) => `X${i}`), inputTypes: Array(inputCount).fill(TextureType.packed), cacheHint }); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/concat.ts b/js/web/lib/onnxjs/backends/webgl/ops/concat.ts index c2b18ef86f814..f85f4032feae1 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/concat.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/concat.ts @@ -30,13 +30,13 @@ export const concat: OperatorImplementation = const createUnpackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => ({ name: 'Concat', - inputNames: Array.from({length: inputCount}, (v, i) => `X${i}`), + inputNames: Array.from({length: inputCount}, (_v, i) => `X${i}`), inputTypes: Array(inputCount).fill(TextureType.unpacked), cacheHint }); const createUnpackedConcatProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { + (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { const inputShape = inputs[0].dims.slice(); if (axis >= inputShape.length || axis < (-1 * inputShape.length)) { throw new Error('axis specified for concat doesn\'t match input dimensionality'); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/gather.ts b/js/web/lib/onnxjs/backends/webgl/ops/gather.ts index 54b6ccd1a3685..bb44a20d75f34 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/gather.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/gather.ts @@ -30,7 +30,7 @@ const gatherProgramMetadata = { }; const createGatherProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { + (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], axis: number): ProgramInfo => { const inputShape = inputs[0].dims.slice(); const indexDataShape = inputs[1].dims.slice(); const outputShape = new Array(inputShape.length + indexDataShape.length - 1); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts b/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts index f74c35b612665..a1da13ec48d70 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/im2col.ts @@ -15,7 +15,7 @@ const createIm2ColProgramMetadata = (cacheHint: string) => ({ }); const createIm2ColProgramInfo = - (inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, x: Tensor, w: Tensor, + (_inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, x: Tensor, w: Tensor, outputShape: readonly number[], attributes: ConvAttributes): ProgramInfo => { const xshape = x.dims; const wshape = w.dims; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts b/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts index 1cd5288251433..efc79f686c960 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts @@ -35,7 +35,7 @@ const imageScalerProgramMetadata = { }; const createImageScalerProgramInfo = - (handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], attributes: ImageScalerAttributes): + (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], attributes: ImageScalerAttributes): ProgramInfo => { const outputShape = inputs[0].dims.slice(); const rank = outputShape.length; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts index fb3c2357ae8fe..0be6d1ba8bcd2 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts @@ -107,10 +107,10 @@ function getBcastSamplerForMatmul( const rankADiff = outRank - inARank; const rankBDiff = outRank - inBRank; - unpackedACoordsSnippet = inAShape.map((s, i) => `coords.${allGlChannels[i + rankADiff]}`); + unpackedACoordsSnippet = inAShape.map((_s, i) => `coords.${allGlChannels[i + rankADiff]}`); unpackedACoordsSnippet[inARank - 1] = 'i*2'; unpackedACoordsSnippet.join(', '); - unpackedBCoordsSnippet = inBShape.map((s, i) => `coords.${allGlChannels[i + rankBDiff]}`); + unpackedBCoordsSnippet = inBShape.map((_s, i) => `coords.${allGlChannels[i + rankBDiff]}`); unpackedBCoordsSnippet[inBRank - 2] = 'i*2'; unpackedBCoordsSnippet.join(', '); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts index 704128fb4858e..523165f29f852 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul.ts @@ -117,7 +117,7 @@ export function getBiasForMatmul( if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { - unpackedCoordsSnippet = inShape.map((s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', '); + unpackedCoordsSnippet = inShape.map((_s, i) => `coords.${allGlChannels[i + rankDiff]}`).join(', '); } const broadcastDims = BroadcastUtil.getBroadcastDims(inShape, outShape); const coordsSnippet = broadcastDims.map(d => `coords.${allGlChannels[d + rankDiff]} = 0;`).join('\n'); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts b/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts index 1a2bc7422c833..c9ea460a6f1fc 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/reduce.ts @@ -46,7 +46,7 @@ export const parseReduceAttributes: OperatorInitialization = ( }; const createReduceProgramInfo = - (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes, name: string, reduceOp: ReduceOp, + (_handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ReduceAttributes, _name: string, reduceOp: ReduceOp, reduceProgramMetadata: ProgramMetadata): ProgramInfo => { const outputShape: number[] = []; const iRank = inputs[0].dims.length || 1; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts index 51acf5042d8bd..c2d703ed04fa0 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/shape.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/shape.ts @@ -4,7 +4,7 @@ import {Tensor} from '../../../tensor'; import {WebGLInferenceHandler} from '../inference-handler'; -export const shape = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { +export const shape = (_inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => { validateInputs(inputs); return [new Tensor([inputs[0].dims.length], 'int32', undefined, undefined, new Int32Array(inputs[0].dims))]; }; @@ -13,4 +13,4 @@ const validateInputs = (inputs: Tensor[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Shape requires 1 input.'); } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/slice.ts b/js/web/lib/onnxjs/backends/webgl/ops/slice.ts index d32a76bbc8628..81fc1b7076fdb 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/slice.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/slice.ts @@ -42,8 +42,8 @@ export const parseSliceAttributes: OperatorInitialization = (no }; const createSliceProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SliceAttributes): ProgramInfo => { - const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((val, i) => i) : attributes.axes; + (_inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SliceAttributes): ProgramInfo => { + const axes = (attributes.axes.length === 0) ? input.dims.slice(0).map((_val, i) => i) : attributes.axes; const normalizedAxes = ShapeUtil.normalizeAxes(axes, input.dims.length); const starts = attributes.starts.map((start, i) => { if (start > input.dims[normalizedAxes[i]] - 1) { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/split.ts b/js/web/lib/onnxjs/backends/webgl/ops/split.ts index d1bd00d47eebd..2ab14563d80e2 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/split.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/split.ts @@ -49,13 +49,13 @@ export const parseSplitAttributes: OperatorInitialization = (no }; const getProgramCount = - (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number, attributes: SplitAttributes): number => { + (_inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number, attributes: SplitAttributes): number => { const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, attributes.split, attributes.numOutputs); return offsets.length; }; const createSplitProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SplitAttributes, axis: number, index: number): + (_inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: SplitAttributes, axis: number, index: number): ProgramInfo => { const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, attributes.split, attributes.numOutputs); const offset = offsets[index]; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/sum.ts b/js/web/lib/onnxjs/backends/webgl/ops/sum.ts index c05286d16f936..2c25b10c5872c 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/sum.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/sum.ts @@ -11,7 +11,7 @@ export const sum = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): const sumProgramMetadata = { name: 'Sum', - inputNames: inputs.map((v, i) => `X${i}`), + inputNames: inputs.map((_v, i) => `X${i}`), inputTypes: new Array(inputs.length).fill(TextureType.unpacked) }; @@ -24,7 +24,7 @@ const createSumProgramInfo = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], sumProgramMetadata: ProgramMetadata): ProgramInfo => { const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); const outputShape = inputs[0].dims.slice(); - const sumLine = inputs.map((v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + '); + const sumLine = inputs.map((_v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + '); const shaderSource = ` void main() { vec4 result = ${sumLine}; @@ -65,4 +65,4 @@ const validateInputs = (inputs: Tensor[]): void => { throw new Error('Input types are not matched.'); } } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/tile.ts b/js/web/lib/onnxjs/backends/webgl/ops/tile.ts index 42128c7abc48c..1d2cba7d9d75f 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/tile.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/tile.ts @@ -22,7 +22,7 @@ export const tile = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): }; const createTileProgramInfo = - (handler: WebGLInferenceHandler, inputs: Tensor[], tileProgramMetadata: ProgramMetadata): ProgramInfo => { + (_handler: WebGLInferenceHandler, inputs: Tensor[], tileProgramMetadata: ProgramMetadata): ProgramInfo => { const inputShape = inputs[0].dims.slice(); const outputShape = new Array(inputShape.length); @@ -63,4 +63,4 @@ const validateInputs = (inputs: Tensor[]): void => { if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') { throw new Error('Invalid repeat type.'); } -}; \ No newline at end of file +}; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts b/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts index 815ff13f1f925..d3e7b3c0823be 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/transpose.ts @@ -36,7 +36,7 @@ export const parseTransposeAttributes: OperatorInitialization createAttributeWithCacheKey({perm: node.attributes.getInts('perm', [])}); const createTransposeProgramInfo = - (inferenceHandler: WebGLInferenceHandler, input: Tensor, perm: number[]): ProgramInfo => { + (_inferenceHandler: WebGLInferenceHandler, input: Tensor, perm: number[]): ProgramInfo => { const inputShape = input.dims; perm = getAdjustedPerm(inputShape, perm); const unpackedOutputShape = getOutputShape(inputShape, perm); diff --git a/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts b/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts index 5135c63702316..4b0cf3f037921 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts @@ -82,7 +82,7 @@ export class RedFloat32DataEncoder implements DataEncoder { } decode(buffer: Encoder.DataArrayType, dataSize: number): Float32Array { if (this.channelSize === 1) { - const filteredData = (buffer as Float32Array).filter((value, index) => index % 4 === 0).subarray(0, dataSize); + const filteredData = (buffer as Float32Array).filter((_value, index) => index % 4 === 0).subarray(0, dataSize); return filteredData; } return buffer.subarray(0, dataSize) as Float32Array; @@ -119,7 +119,7 @@ export class RGBAFloatDataEncoder implements DataEncoder { } decode(buffer: Encoder.DataArrayType, dataSize: number): Float32Array { if (this.channelSize === 1) { - const filteredData = (buffer as Float32Array).filter((value, index) => index % 4 === 0).subarray(0, dataSize); + const filteredData = (buffer as Float32Array).filter((_value, index) => index % 4 === 0).subarray(0, dataSize); return filteredData; } return buffer.subarray(0, dataSize) as Float32Array; diff --git a/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts b/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts index c89ef3d23638d..f8e370747928c 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts @@ -105,7 +105,7 @@ export class PreferLogicalStrategy implements TextureLayoutStrategy { // tensor has 3 rows, we pretend it has 4 rows in order to account for the // fact that the texels containing the third row are half empty. logShape = logShape.map( - (d, i) => i >= logShape.length - 2 ? (logShape[i] % 2 === 0 ? logShape[i] : logShape[i] + 1) : logShape[i]); + (_d, i) => i >= logShape.length - 2 ? (logShape[i] % 2 === 0 ? logShape[i] : logShape[i] + 1) : logShape[i]); // Packed texture height is at least 2 (the channel height of a single // texel). @@ -182,7 +182,7 @@ export function parseAxisParam(axis: number|number[], shape: number[]): number[] const rank = shape.length; // Normalize input - axis = axis == null ? shape.map((s, i) => i) : ([] as number[]).concat(axis); + axis = axis == null ? shape.map((_s, i) => i) : ([] as number[]).concat(axis); // Check for valid range assert( diff --git a/js/web/lib/onnxjs/backends/webgl/texture-manager.ts b/js/web/lib/onnxjs/backends/webgl/texture-manager.ts index 1cb113e5fa630..effb65288dc1c 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-manager.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-manager.ts @@ -172,7 +172,7 @@ export class TextureManager { throw new Error(`TensorData type ${dataType} is not supported`); } } - toTextureData(dataType: Tensor.DataType, data: Tensor.NumberType|undefined): Encoder.DataArrayType|undefined { + toTextureData(_dataType: Tensor.DataType, data: Tensor.NumberType|undefined): Encoder.DataArrayType|undefined { if (!data) { return undefined; } diff --git a/js/web/lib/onnxjs/execution-plan.ts b/js/web/lib/onnxjs/execution-plan.ts index b95e639817dbf..5599087ab46f5 100644 --- a/js/web/lib/onnxjs/execution-plan.ts +++ b/js/web/lib/onnxjs/execution-plan.ts @@ -114,7 +114,7 @@ export class ExecutionPlan { // resolve downstream nodes const downstreamNodes = new Set(); - outputList.forEach((output, i) => { + outputList.forEach((_output, i) => { const j = thisOp.node.outputs[i]; for (const currentDownstreamNodeIndex of graphValues[j].to) { const currentDownstreamNode = graphNodes[currentDownstreamNodeIndex]; diff --git a/js/web/lib/onnxjs/instrument.ts b/js/web/lib/onnxjs/instrument.ts index 4c543cab157d7..4f865503d50ec 100644 --- a/js/web/lib/onnxjs/instrument.ts +++ b/js/web/lib/onnxjs/instrument.ts @@ -176,7 +176,7 @@ function createCategorizedLogger(category: string): Logger.CategorizedLogger { // NOTE: argument 'category' is put the last parameter beacause typescript // doesn't allow optional argument put in front of required argument. This // order is different from a usual logging API. -function logInternal(severity: Logger.Severity, content: string, stack: number, category?: string) { +function logInternal(severity: Logger.Severity, content: string, _stack: number, category?: string) { const config = LOGGER_CONFIG_MAP[category || ''] || LOGGER_CONFIG_MAP['']; if (SEVERITY_VALUE[severity] < SEVERITY_VALUE[config.minimalSeverity]) { return; diff --git a/js/web/lib/onnxjs/util.ts b/js/web/lib/onnxjs/util.ts index 0a76d75e79bbf..d697a8b3138cf 100644 --- a/js/web/lib/onnxjs/util.ts +++ b/js/web/lib/onnxjs/util.ts @@ -967,7 +967,7 @@ export class ReduceUtil { const dims = a.dims.slice(0); // if axes is not set, perform reduce on all axes if (axes.length === 0) { - dims.forEach((d, ind) => axes.push(ind)); + dims.forEach((_d, ind) => axes.push(ind)); } // get a temporary broadcastable output shape const outputDims = ReduceUtil.calcReduceShape(dims, axes, true); diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index a4d51e68b6a25..9f5dceb8f4726 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; +import {attention, parseAttentionAttributes} from './ops/attention'; import {biasAdd} from './ops/bias-add'; import {biasSplitGelu} from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; @@ -16,6 +17,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; import {range} from './ops/range'; @@ -46,6 +48,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], ['Atanh', [unaryOps.atanh]], + ['Attention', [attention, parseAttentionAttributes]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], ['BiasAdd', [biasAdd]], @@ -86,6 +89,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], + ['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]], ['Neg', [unaryOps.neg]], ['Not', [unaryOps.not]], ['Pad', [pad, parsePadAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts index 2f8d072c7dd4f..b6c6853c8f222 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts @@ -27,11 +27,6 @@ export interface ArgMinMaxAttributes extends AttributeWithCacheKey { selectLastIndex: number; } -const createArgMinMaxAttributesFromInputs = - (inputs: readonly TensorView[], attributes: ArgMinMaxAttributes): ArgMinMaxAttributes => - createAttributeWithCacheKey( - {axis: attributes.axis, keepDims: attributes.keepDims, selectLastIndex: attributes.selectLastIndex}); - export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => { validateInputs(context.inputs); const argMinMaxOp: ReduceOp = (input, output, axes) => { @@ -51,12 +46,10 @@ export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes) ]; }; - const updatedAttributes: ArgMinMaxAttributes = - context.inputs.length === 1 ? attributes : createArgMinMaxAttributesFromInputs(context.inputs, attributes); context.compute( createReduceProgramInfo( - 'ArgMin', {hint: updatedAttributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [updatedAttributes.axis], - DataType.int64, updatedAttributes.keepDims), + 'ArgMin', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64, + attributes.keepDims), {inputs: [0]}); }; @@ -79,12 +72,10 @@ export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes) ]; }; - const updatedAttributes: ArgMinMaxAttributes = - context.inputs.length === 1 ? attributes : createArgMinMaxAttributesFromInputs(context.inputs, attributes); context.compute( createReduceProgramInfo( - 'argMax', {hint: updatedAttributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [updatedAttributes.axis], - DataType.int64, updatedAttributes.keepDims), + 'argMax', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64, + attributes.keepDims), {inputs: [0]}); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts new file mode 100644 index 0000000000000..e1f2a47301bfb --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -0,0 +1,635 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; + +export const enum AttentionQkvFormat { + unknown, // enum value not set, or depends on qkv projection implementation details + qkvBNSH, // for non-packed qkv, permuted + qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + qkvBSN3H, // for TRT fused attention, qkv are packed + qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed + qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed +} + +export const enum AttentionMaskType { + none, // No mask + mask1dKeySeqLen, // [batch_size], key sequence length + mask1dEndStart, // [2 * batch_size] with end positions and start positions + mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], + // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., + // key_start[batch_size - 1], key_end[batch_size - 1]] + mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. + mask2dKeyPadding, // [batch_size, total_sequence_length] + mask3dAttention, // [batch_size, sequence_length, total_sequence_length] + mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + maskUnknown +} + +export interface AttentionParameters { + batchSize: number; + sequenceLength: number; + pastSequenceLength: number; + kvSequenceLength: number; + totalSequenceLength: number; + maxSequenceLength: number; + inputHiddenSize: number; + hiddenSize: number; + vHiddenSize: number; + headSize: number; + vHeadSize: number; + numHeads: number; + isUnidirectional: boolean; + pastPresentShareBuffer: boolean; + maskFilterValue: number; + maskType: AttentionMaskType; + scale: number; + broadcastResPosBias: boolean; + passPastInKv: boolean; + qkvFormat: AttentionQkvFormat; +} + +export interface AttentionAttrs { + numHeads: number; + isUnidirectional: number; + maskFilterValue: number; + scale: number; + doRotary: number; + qkvHiddenSizes: number[]; + pastPresentShareBuffer: boolean; +} + +const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // When past state is used, Q, K and V should have same hidden size (unless we split it into past_key and past_value). + + // Input shapes: + // input (Q/K/V) : (B, S, D_i) + // weights (Q/K/V) : (D_i, D + D + D_v) + // bias (Q/K/V) : (D + D + D_v) + // mask_index : see below + // past (K/V) : (2, B, N, P, H) or NULL + // relative_position_bias : (B, N, S, T) or NULL + + // For mask_index, the following shapes are supported: + // NULL, (B, 1), (1, 1) + // (B), (2 * B), (3 * B + 2) + // (B, T) + // (B, S, T) + // (B, 1, M, M) + // + // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger + // than hidden dimension of Q, K and V. + + const input = inputs[0]; + const weights = inputs[1]; + const bias = inputs[2]; + const maskIndex = inputs[3]; + const past = inputs[4]; + const relativePositionBias = inputs[5]; + + if (past && relativePositionBias) { + throw new Error('Attention cannot have both past and relative_position_bias'); + } + + if (input.dims.length !== 3) { + throw new Error('Input "input" must have 3 dimensions'); + } + + const batchSize = input.dims[0]; + const sequenceLength = input.dims[1]; + const inputHiddenSize = input.dims[2]; + + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimensions'); + } + + if (weights.dims.length !== 2) { + throw new Error('Input "weights" is expected to have 2 dimensions'); + } + + if (weights.dims[0] !== inputHiddenSize) { + throw new Error('Input 1 dimension 0 should have same length as dimension 2 of input 0'); + } + + if (bias.dims[0] !== weights.dims[1]) { + throw new Error('Input "bias" dimension 0 should have same length as dimension 1 of input "weights"'); + } + + let qHiddenSize = bias.dims[0] / 3; + let kHiddenSize = qHiddenSize; + let vHiddenSize = kHiddenSize; + if (attributes.qkvHiddenSizes.length > 0) { + if (attributes.qkvHiddenSizes.length !== 3) { + throw new Error('qkv_hidden_sizes attribute should have 3 elements'); + } + for (const sz of attributes.qkvHiddenSizes) { + if (sz % attributes.numHeads !== 0) { + throw new Error('qkv_hidden_sizes should be divisible by num_heads'); + } + } + + qHiddenSize = attributes.qkvHiddenSizes[0]; + kHiddenSize = attributes.qkvHiddenSizes[1]; + vHiddenSize = attributes.qkvHiddenSizes[2]; + } + + const kvSequenceLength = sequenceLength; + + if (qHiddenSize !== kHiddenSize) { + throw new Error('qkv_hidden_sizes first element should be same as the second'); + } + + if (bias.dims[0] !== qHiddenSize + kHiddenSize + vHiddenSize) { + throw new Error('Input "bias" dimension 0 should have same length as sum of Q/K/V hidden sizes'); + } + + let pastSequenceLength = 0; + if (past) { + if (kHiddenSize !== vHiddenSize) { + throw new Error('Input "past" expect k_hidden_size == v_hidden_size'); + } + if (past.dims.length !== 5) { + throw new Error('Input "past" must have 5 dimensions'); + } + if (past.dims[0] !== 2) { + throw new Error('Input "past" first dimension must be 2'); + } + if (past.dims[1] !== batchSize) { + throw new Error('Input "past" second dimension must be batch_size'); + } + if (past.dims[2] !== attributes.numHeads) { + throw new Error('Input "past" third dimension must be num_heads'); + } + if (past.dims[4] !== kHiddenSize / attributes.numHeads) { + throw new Error('Input "past" fifth dimension must be k_hidden_size / num_heads'); + } + + if (!attributes.pastPresentShareBuffer) { + pastSequenceLength = past.dims[3]; + } + // TODO: handle past_seq_len + } + + const totalSequenceLength = kvSequenceLength + pastSequenceLength; + const maxSequenceLength = -1; + + const maskType = AttentionMaskType.none; + if (maskIndex) { + // maskType = AttentionMaskType.MASK_UNKNOWN; + // TODO: handle mask + throw new Error('Mask not supported'); + } + + if (past) { + throw new Error('past is not supported'); + } + if (relativePositionBias) { + throw new Error('relativePositionBias is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize, + hiddenSize: qHiddenSize, + vHiddenSize, + headSize: Math.floor(qHiddenSize / attributes.numHeads), + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias: false, + passPastInKv: false, + qkvFormat: AttentionQkvFormat.qkvBNSH, + }; +}; + +export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => { + const components = getMaxComponents(d); + const inputHelper = outputVariable('x', input.dataType, input.dims, components); + + let threadMaxValue = 'threadMaxVector'; + if (components === 2) { + threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)'; + } else if (components === 4) { + threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))'; + } + const dataType = tensorTypeToWsglStorageType(input.dataType); + let WG = 64; + const dComp = d / components; + if (dComp < WG) { + WG = 1; + } else if (dComp / 8 < 64) { + WG = Math.ceil(dComp / 8); + } + const elementsPerWG = Math.ceil(d / components / WG); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const dInv: ${dataType} = 1 / ${d}; + const dComp = ${d / components}; + var wgMax: array; + var wgSum: array; + + ${shaderHelper.declareVariables(inputHelper)} + @compute @workgroup_size(${WG}, 1, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_index) local_index : u32) { + let localOffset = local_index * ${elementsPerWG}; + let offset: u32 = workgroup_id.x * dComp + localOffset; + + var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')}; + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector); + } + wgMax[local_index] = ${threadMaxValue}; + workgroupBarrier(); + + var maxValue = -3.402823e+38f; + for (var i = 0u; i < ${WG}; i++) { + maxValue = max(wgMax[i], maxValue); + } + + var sumVector = ${fillVector('f32', components, '0')}; + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue); + } + wgSum[local_index] = ${sumVector('sumVector', components)}; + workgroupBarrier(); + + var sum: f32 = 0; + for (var i = 0u; i < ${WG}; i++) { + sum += wgSum[i]; + } + + if (sum == 0) { + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + x[offset + i] = ${fillVector(dataType, components, 'dInv')}; + } + } else { + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + let f32input = ${castToF32(dataType, components, 'x[offset + i]')}; + x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum); + } + } + }`; + + context.compute( + { + name: 'AttentionProbsSoftmax', + shaderCache: {hint: `${d}`}, + getShaderSource, + getRunData: () => ({ + outputs: [], + dispatchGroup: {x: n}, + }), + }, + {inputs: [input], outputs: []}); +}; + +const computeAttentionProbs = + (context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined, + parameters: AttentionParameters, attributes: AttentionAttrs) => { + const probsShape = [ + parameters.batchSize, parameters.numHeads, parameters.sequenceLength, + parameters.kvSequenceLength + parameters.pastSequenceLength + ]; + // TODO: handle mask + + const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; + + const dataType = tensorTypeToWsglStorageType(q.dataType); + + const components = getMaxComponents(parameters.headSize); + const qInput = inputVariable('q', q.dataType, q.dims, components); + const kInput = inputVariable('key', key.dataType, key.dims, components); + const output = outputVariable('output', q.dataType, probsShape); + + const vectorizedHeadSize = parameters.headSize / components; + const M = parameters.sequenceLength; + const N = parameters.totalSequenceLength; + const K = vectorizedHeadSize; + + const TILE_SIZE = 12; + + const dispatch = { + x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + + const inputs = [q, key]; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${M}u; + const N: u32 = ${N}u; + const K: u32 = ${K}u; + const alpha: ${dataType} = ${alpha}; + const beta: ${dataType} = 1.0; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + + ${shaderHelper.declareVariables(qInput, kInput, output)} + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + // x holds the N and y holds the M + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE; + let n = workgroup_id.x * TILE_SIZE; + let lm = m + local_id.y; + let ln = n + local_id.x; + + let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; + let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; + + var value = ${fillVector(dataType, components)}; + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m + local_id.y < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x]; + } + if (n + local_id.y < N && w + local_id.x < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x]; + } + workgroupBarrier(); + + for (var k: u32 = 0u; k ({ + outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs, outputs: [-1]})[0]; + + computeInPlaceSoftmax( + context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength, + parameters.totalSequenceLength); + + return probs; + }; + +const computeVxAttentionScore = + (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { + const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize]; + + const probsHelper = inputVariable('probs', probs.dataType, probs.dims); + const vHelper = inputVariable('v', v.dataType, v.dims); + const output = outputVariable('output', probs.dataType, outputShape); + + const dataType = tensorTypeToWsglStorageType(probs.dataType); + + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(params.vHeadSize / TILE_SIZE), + y: Math.ceil(params.sequenceLength / TILE_SIZE), + z: params.batchSize * params.numHeads + }; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${params.sequenceLength}u; + const N: u32 = ${params.vHeadSize}u; + const K: u32 = ${params.totalSequenceLength}u; + const numHeads: u32 = ${params.numHeads}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + + ${shaderHelper.declareVariables(probsHelper, vHelper, output)} + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; + + let offsetA = headIdx * (M * K) + m * K; + let offsetB = headIdx * (N * K) + n; + + var value = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({ + outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs: [probs, v], outputs: [0]})[0]; + }; + +export const applyAttention = + (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined, + _past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined, + relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { + const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes); + + computeVxAttentionScore(context, probs, v, parameters); + }; + +const prepare = (context: ComputeContext, parameters: AttentionParameters) => { + const outputShape = [ + parameters.batchSize, + parameters.numHeads, + parameters.sequenceLength, + parameters.headSize, + ]; + + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + + const M = parameters.sequenceLength; + const K = parameters.inputHiddenSize; + const N = parameters.headSize; + + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(parameters.headSize / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + + const getShaderSource = () => ` + const M: u32 = ${M}u; + const K: u32 = ${K}u; + const N: u32 = ${N}u; + const numHeads: u32 = ${parameters.numHeads}; + const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + + @group(0) @binding(0) var input: array<${dataType}>; + @group(0) @binding(1) var weight: array<${dataType}>; + @group(0) @binding(2) var bias: array<${dataType}>; + @group(0) @binding(3) var outputQ: array<${dataType}>; + @group(0) @binding(4) var outputK: array<${dataType}>; + @group(0) @binding(5) var outputV: array<${dataType}>; + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + let batchIndex = workgroup_id.z / ${parameters.numHeads}; + let headNumber = workgroup_id.z % ${parameters.numHeads}; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; + + let inputOffset = batchIndex * (M * K) + m * K; + let biasOffsetQ = headNumber * ${parameters.headSize}; + let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ; + let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK; + + var valueQ = ${dataType}(0); + var valueK = ${dataType}(0); + var valueV = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + let offset = n + (w + local_id.y) * ldb; + tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset]; + tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset]; + tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({ + outputs: [ + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + ], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs, outputs: [-1, -1, -1]}); +}; + +export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateAttentionInputs(context.inputs, attributes); + + const [q, k, v] = prepare(context, params); + + return applyAttention( + context, q, k, v, context.inputs[4], undefined, undefined, undefined, context.inputs[5], params, attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 38dc14f23682e..014d9d02f6f10 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -646,6 +646,8 @@ export const outputVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, false, components); +export type UniformsArrayType = Array<{name: string; type: string}>; + /** * A ShaderHelper is a helper class for generating WGSL code. */ @@ -697,6 +699,7 @@ export interface ShaderHelper { * A helper function to register one uniform. Can be called multiple times to register multiple uniforms. */ registerUniform(name: string, type: string): ShaderHelper; + registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper; } class ShaderHelperImpl implements ShaderHelper { @@ -755,8 +758,13 @@ class ShaderHelperImpl implements ShaderHelper { return this; } + registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper { + this.uniforms = this.uniforms.concat(additionalUniforms); + return this; + } + private indicesHelpers: IndicesHelper[] = []; - private uniforms: Array<{name: string; type: string}> = []; + private uniforms: UniformsArrayType = []; private uniformDeclaration(): string { if (this.uniforms.length === 0) { return ''; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index 357eb5c0b84ad..a233d37a79e65 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -254,7 +254,7 @@ const createEinsumProgramInfo = (inputs: readonly TensorView[], einsumEquation: ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} var outputIndices = ${output.offsetToIndices('global_idx')}; - ${inputVars.map((inputVar, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} + ${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} ${reduceOps.join('\n')}; ${output.setByOffset('global_idx', 'sum')}; }`; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts new file mode 100644 index 0000000000000..b7726a36bcaad --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -0,0 +1,335 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; + +const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + const query = inputs[0]; + const key = inputs[1]; + const value = inputs[2]; + const bias = inputs[3]; + const keyPaddingMask = inputs[4]; + const relativePositionBias = inputs[5]; + const pastKey = inputs[6]; + const pastValue = inputs[7]; + + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None + // relative_position_bias : (B, 1, S, L) + // past_key : (B, N, S*, H) + // past_value : (B, N, S*, H) + // When no packing for q/k/v: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) or (B, N, S*, H) + // value (V) : (B, L, D_v) or (B, N, S*, H) + // bias (Q/K/V) : (D + D + D_v) + // When packed kv is used: + // query (Q) : (B, S, D) + // key (K) : (B, L, N, 2, H) + // value (V) : None + // bias (Q/K/V) : None + // When packed qkv is used: + // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) + // key (K) : None + // value (V) : None + // bias (Q/K/V) : None or (D + D + D_v) + + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input query is expected to have 3 or 5 dimensions'); + } + + const dmmhaPacking = false; + const batchSize = query.dims[0]; + const sequenceLength = query.dims[1]; + const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : + attributes.numHeads * query.dims[4]; + let kvSequenceLength = sequenceLength; + + let pastSequenceLength = 0; + let maxSequenceLength = 0; + const headSize = Math.floor(hiddenSize / attributes.numHeads); + if (pastKey && pastValue) { + if (pastKey.dims.length !== 4) { + throw new Error('Input "past_key" is expected to have 4 dimensions'); + } + if (pastValue.dims.length !== 4) { + throw new Error('Input "past_value" is expected to have 4 dimensions'); + } + pastSequenceLength = pastKey.dims[2]; + maxSequenceLength = pastKey.dims[2]; + } else if (pastKey || pastValue) { + throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); + } + + let qkvFormat: AttentionQkvFormat; + if (key) { + if (query.dims.length !== 3) { + throw new Error('Input "query" is expected to have 3 dimensions when key is given'); + } + if (key.dims.length < 3 || key.dims.length > 5) { + throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions'); + } + if (query.dims[0] !== key.dims[0]) { + throw new Error('Input "query" and "key" shall have same dim 0 (batch size)'); + } + + if (key.dims.length === 3) { + if (key.dims[2] !== query.dims[2]) { + throw new Error('Input "query" and "key" shall have same dim 2 (hidden_size)'); + } + qkvFormat = AttentionQkvFormat.qkvBSNH; + kvSequenceLength = key.dims[1]; + } else if (key.dims.length === 5) { + if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) { + throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv'); + } + if (value) { + throw new Error('Expect "value" be none when "key" has packed kv format.'); + } + qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; + kvSequenceLength = key.dims[1]; + } else { // key_dims.size() == 4 (cross-attention with past_key) + if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { + throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); + } + + qkvFormat = AttentionQkvFormat.unknown; + kvSequenceLength = key.dims[2]; + } + } else { // packed QKV + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); + } + if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) { + throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv'); + } + + qkvFormat = AttentionQkvFormat.qkvBSN3H; + } + + if (bias) { + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimension'); + } + + if (value) { + if (query.dims.length === 5 && query.dims[3] === 2) { + throw new Error('bias is not allowed for packed kv.'); + } + } + } + + let maskType: AttentionMaskType = AttentionMaskType.none; + if (keyPaddingMask) { + maskType = AttentionMaskType.maskUnknown; + const maskDims = keyPaddingMask.dims; + if (maskDims.length === 1) { + if (maskDims[0] === batchSize) { + maskType = AttentionMaskType.mask1dKeySeqLen; + } else if (maskDims[0] === 3 * batchSize + 2) { + maskType = AttentionMaskType.mask1DKeySeqLenStart; + } + } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) { + maskType = AttentionMaskType.mask2dKeyPadding; + } + if (maskType === AttentionMaskType.maskUnknown) { + throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, kv_sequence_length)'); + } + throw new Error('Mask not supported'); + } + + let passPastInKv = false; + let vHiddenSize = hiddenSize; + if (value) { + if (value.dims.length !== 3 && value.dims.length !== 4) { + throw new Error('Input "value" is expected to have 3 or 4 dimensions'); + } + + if (query.dims[0] !== value.dims[0]) { + throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)'); + } + + if (value.dims.length === 3) { + if (kvSequenceLength !== value.dims[1]) { + throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)'); + } + vHiddenSize = value.dims[2]; + } else { + if (kvSequenceLength !== value.dims[2]) { + throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)'); + } + vHiddenSize = value.dims[1] * value.dims[3]; + passPastInKv = true; + } + } + + const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const broadcastResPosBias = false; + // if (extraAddQk) { + // if (extraAddQk.dims[0] === 1) { + // broadcastResPosBias = true; + // } + // } + + if (keyPaddingMask) { + throw new Error('Key padding mask is not supported'); + } + if (relativePositionBias) { + throw new Error('extraAddQk is not supported'); + } + if (pastKey) { + throw new Error('pastKey is not supported'); + } + if (pastValue) { + throw new Error('pastValue is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize: 0, + hiddenSize, + vHiddenSize, + headSize, + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias, + passPastInKv, + qkvFormat, + }; +}; + + +export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]}); + +const addBiasTranspose = + (context: ComputeContext, qkv: TensorView, bias: TensorView, batchSize: number, sequenceLength: number, + hiddenSize: number, biasOffset: number) => { + const outputShape = [batchSize, sequenceLength, hiddenSize]; + const outputSize = ShapeUtil.size(outputShape); + + const dataType = tensorTypeToWsglStorageType(qkv.dataType); + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const biasOffset = ${biasOffset}u; + const hiddenSize = ${hiddenSize}u; + + @group(0) @binding(0) var qkv: array<${dataType}>; + @group(0) @binding(1) var bias: array<${dataType}>; + @group(0) @binding(2) var qkv_with_bias: array<${dataType}>; + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset; + + qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx]; + }`; + + return context.compute( + { + name: 'MultiHeadAttentionAddBias', + shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + }), + getShaderSource, + }, + {inputs: [qkv, bias], outputs: [-1]})[0]; + }; + +const maybeTransposeToBNSHAndAddBias = + (context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number, + input: TensorView, bias?: TensorView, biasOffset?: number) => { + // const newDims = []; + + let reshapedInput = input; + if (!bias) { + if (input.dims.length === 3) { + reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); + } + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } else { + if (sequenceLength === 1) { + throw new Error('AddBiasReshape is not implemented. Please export your model with packed QKV or KV'); + } else { + reshapedInput = + addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!); + reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } + } + }; + +export const multiHeadAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateInputs(context.inputs, attributes); + + if (context.inputs[0].dims.length === 5) { + throw new Error('Packed QKV is not implemented'); + } + + if (context.inputs[1]?.dims.length === 5) { + throw new Error('Packed KV is not implemented'); + } + + // applyAttention expects BNSH inputs + const kvBNSH = context.inputs[1] && context.inputs[2] && context.inputs[1].dims.length === 4 && + context.inputs[2].dims.length === 4; + + const Q = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0], + context.inputs[3], 0); + + if (kvBNSH) { + return applyAttention( + context, Q, context.inputs[1], context.inputs[2], context.inputs[4], undefined, undefined, undefined, + context.inputs[5], params, attributes); + } + + const K = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.headSize, context.inputs[1], + context.inputs[3], params.hiddenSize); + + const V = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.vHeadSize, context.inputs[2], + context.inputs[3], 2 * params.hiddenSize); + + applyAttention( + context, Q, K, V, context.inputs[4], undefined, context.inputs[6], context.inputs[7], context.inputs[5], params, + attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index 180dab92a453a..18859e253aa02 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -36,8 +36,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { }; const getPadConstant = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[], dataType: string, constantValue: number): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[], + dataType: string, constantValue: number): string => { const inputRank = inputDims.length; let block = ''; @@ -66,8 +66,7 @@ const getPadConstant = }; const getPadReflect = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[]): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { const inputRank = inputDims.length; let block = ''; @@ -97,8 +96,7 @@ const getPadReflect = }; const getPadEdge = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[]): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { const inputRank = inputDims.length; let block = ''; @@ -124,8 +122,7 @@ const getPadEdge = }; const getPadWrap = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], pads: number[]): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => { const inputRank = inputDims.length; let block = ''; @@ -151,18 +148,17 @@ const getPadWrap = }; const getPadSnippet = - (output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[], - inputStrides: readonly number[], attributes: PadAttributes, dataType: string): string => { + (output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], attributes: PadAttributes, + dataType: string): string => { switch (attributes.mode) { case 0: - return getPadConstant( - output, outputDims, inputDims, inputStrides, attributes.pads, dataType, attributes.value); + return getPadConstant(output, inputDims, inputStrides, attributes.pads, dataType, attributes.value); case 1: - return getPadReflect(output, outputDims, inputDims, inputStrides, attributes.pads); + return getPadReflect(output, inputDims, inputStrides, attributes.pads); case 2: - return getPadEdge(output, outputDims, inputDims, inputStrides, attributes.pads); + return getPadEdge(output, inputDims, inputStrides, attributes.pads); case 3: - return getPadWrap(output, outputDims, inputDims, inputStrides, attributes.pads); + return getPadWrap(output, inputDims, inputStrides, attributes.pads); default: throw new Error('Invalid mode'); } @@ -179,7 +175,7 @@ const generatePadCode = const output = outputVariable('output', inputs[0].dataType, outputDims); const input = inputVariable('x', inputs[0].dataType, inputDims); - const padSnippet = getPadSnippet(output, outputDims, inputDims, inputStrides, attributes, dataType); + const padSnippet = getPadSnippet(output, inputDims, inputStrides, attributes, dataType); const padCode = ` ${shaderHelper.declareVariables(input, output)} ${shaderHelper.mainStart()} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts index 54a7414360c4e..1365d1e9a12a4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -199,7 +199,7 @@ const reduceCommon = let updatedAxes = updatedAttributes.axes; if (updatedAxes.length === 0 && !updatedAttributes.noopWithEmptyAxes) { - updatedAxes = context.inputs[0].dims.map((s, i) => i); + updatedAxes = context.inputs[0].dims.map((_dim, i) => i); } const normalizeAxes = ShapeUtil.normalizeAxes(updatedAxes, context.inputs[0].dims.length); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 07cfefb8f191b..9869561a36251 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -218,32 +218,30 @@ const initOutputShape = return outputShape; }; -const adjustOutputShape = - (inputShape: readonly number[], outputShape: readonly number[], scales: number[], attributes: ResizeAttributes): - number[] => { - const scaleInPolicy = (() => { - switch (attributes.keepAspectRatioPolicy) { - case 'not_larger': - return attributes.axes.length > 0 ? Math.min(...attributes.axes.map(i => scales[i]), Number.MAX_VALUE) : - Math.min(...scales, Number.MAX_VALUE); - case 'not_smaller': - return attributes.axes.length > 0 ? Math.max(...attributes.axes.map(i => scales[i]), Number.MIN_VALUE) : - Math.max(...scales, Number.MIN_VALUE); - default: - throw new Error(`Keep aspect ratio policy ${attributes.keepAspectRatioPolicy} is not supported`); - } - })(); - scales.fill(1.0, 0, scales.length); - const adjustedOutputShape = inputShape.slice(); - if (attributes.axes.length > 0) { - attributes.axes.forEach((v) => scales[v] = scaleInPolicy); - attributes.axes.forEach((v) => adjustedOutputShape[v] = Math.round(inputShape[v] * scales[v])); - } else { - scales.fill(scaleInPolicy, 0, scales.length); - adjustedOutputShape.forEach((v, i) => adjustedOutputShape[i] = Math.round(v * scales[i])); - } - return adjustedOutputShape; - }; +const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes): number[] => { + const scaleInPolicy = (() => { + switch (attributes.keepAspectRatioPolicy) { + case 'not_larger': + return attributes.axes.length > 0 ? Math.min(...attributes.axes.map(i => scales[i]), Number.MAX_VALUE) : + Math.min(...scales, Number.MAX_VALUE); + case 'not_smaller': + return attributes.axes.length > 0 ? Math.max(...attributes.axes.map(i => scales[i]), Number.MIN_VALUE) : + Math.max(...scales, Number.MIN_VALUE); + default: + throw new Error(`Keep aspect ratio policy ${attributes.keepAspectRatioPolicy} is not supported`); + } + })(); + scales.fill(1.0, 0, scales.length); + const adjustedOutputShape = inputShape.slice(); + if (attributes.axes.length > 0) { + attributes.axes.forEach((v) => scales[v] = scaleInPolicy); + attributes.axes.forEach((v) => adjustedOutputShape[v] = Math.round(inputShape[v] * scales[v])); + } else { + scales.fill(scaleInPolicy, 0, scales.length); + adjustedOutputShape.forEach((v, i) => adjustedOutputShape[i] = Math.round(v * scales[i])); + } + return adjustedOutputShape; +}; const calculateOriginalIndicesFromOutputIndices = (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], @@ -314,8 +312,8 @@ const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): }`; const bilinearInterpolation = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], - scales: readonly number[], useExtrapolation: boolean, extrapolationValue: number): string => { + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[], + useExtrapolation: boolean, extrapolationValue: number): string => { const [batchIdx, heightIdx, widthIdx, channelIdx] = inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]); return ` @@ -445,7 +443,7 @@ const createResizeProgramInfo = if (scalesInput.length === 0) { scales = inputShape.map((value, index) => value === 0 ? 1.0 : outputShape[index] / value); if (attributes.keepAspectRatioPolicy !== 'stretch') { - outputShape = adjustOutputShape(inputShape, outputShape, scales, attributes); + outputShape = adjustOutputShape(inputShape, scales, attributes); } } const output = outputVariable('output', inputTensor.dataType, outputShape); @@ -471,7 +469,7 @@ const createResizeProgramInfo = ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales, roi)}; ${ bilinearInterpolation( - input, output, inputShape, outputShape, scales, useExtrapolation, attributes.extrapolationValue)}; + input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}; `; case 'cubic': return ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index d607351f69b74..7458579bf4340 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; export interface SliceAttributes extends AttributeWithCacheKey { readonly starts: number[]; @@ -77,17 +77,26 @@ const fixStartEndValues = }; const calculateInputIndicesImpl = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]): - string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], + enableInputShapeUniforms: boolean): string => + `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { var inputIndices: ${input.type.indices}; var carry = 0u; for (var i = ${inputShape.length}; i >= 0; i--) { + let input_shape_i = ${ + enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'}; + let steps_i = ${ + enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'}; + let signs_i = ${ + enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'}; + let starts_i = ${ + enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'}; var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; - var inputIndex = outputIndex * steps[i] + starts[i] + carry; - carry = inputIndex / inputShape[i]; - inputIndex = inputIndex % inputShape[i]; - if (signs[i] < 0) { - inputIndex = inputShape[i] - inputIndex - 1u + starts[i]; + var inputIndex = outputIndex * steps_i + starts_i + carry; + carry = inputIndex / input_shape_i; + inputIndex = inputIndex % input_shape_i; + if (signs_i < 0) { + inputIndex = input_shape_i - inputIndex - 1u + starts_i; } ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex; } @@ -110,6 +119,10 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps)); + if (axes.length !== starts.length || axes.length !== ends.length) { + throw new Error('start, ends and axes should have the same number of elements'); + } + if (axes.length !== inputShape.length) { for (let i = 0; i < inputShape.length; ++i) { if (!axes.includes(i)) { @@ -131,40 +144,66 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice array[i] = -step; } }); + // Output rank is expected to be less than or equal to the input rank. + const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length); + const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims; const outputShape = inputShape.slice(0); axes.forEach((axis, _) => { outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); }); + const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape; const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType}; - const output = outputVariable('output', inputs[0].dataType, outputShape); - const input = inputVariable('input', inputs[0].dataType, inputShape); + const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); + const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank); const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = []; + const uniforms: UniformsArrayType = []; + if (enableShapeUniforms) { + uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}` : 'u32'}); + uniforms.push({name: 'signs', type: signs.length > 1 ? `vec${signs.length}` : 'i32'}); + uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}` : 'u32'}); + programUniforms.push({type: 'uint32', data: starts}); + programUniforms.push({type: 'int32', data: signs}); + programUniforms.push({type: 'uint32', data: steps}); + } + uniforms.push({name: 'outputSize', type: 'u32'}); + programUniforms.push({type: 'uint32', data: outputSize}); + if (enableShapeUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(outputShape)); + } const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, output)} - const signs = array(${signs.map(i => `${i}i`).join(',')}); - const starts = array(${starts.map(i => `${i}u`).join(',')}); - const ends = array(${ends.map(i => `${i}u`).join(',')}); - const steps = array(${steps.map(i => `${i}u`).join(',')}); - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - - ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} + ${enableShapeUniforms ? '' : [ + `const signs = array(${signs.map(i => `${i}i`).join(',')});`, + `const starts = array(${starts.map(i => `${i}u`).join(',')});`, + `const steps = array(${steps.map(i => `${i}u`).join(',')});`, + `const inputShape = array(${inputShape.map(i => `${i}u`).join(',')});` + ].join('\n')} + + ${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = ${output.offsetToIndices('global_idx')}; let inputIndices = calculateInputIndices(outputIndices); ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} }`; return { name: 'Slice', - shaderCache: {hint: `${attributes.cacheKey}|${inputs[4]?.dims ?? ''}`}, + shaderCache: { + hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` : + `${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`, + inputDependencies: [enableShapeUniforms ? 'rank' : 'dims'] + }, getShaderSource, getRunData: () => ({ outputs: [outputTensorInfo], dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, + programUniforms }) }; }; diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 1f4595818e5c0..1cb6d9e391e4f 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -3,7 +3,39 @@ /// -import {OrtWasmMessage} from '../proxy-messages'; +// +// * type hack for "HTMLImageElement" +// +// in typescript, the type of "HTMLImageElement" is defined in lib.dom.d.ts, which is conflict with lib.webworker.d.ts. +// when we use webworker, the lib.webworker.d.ts will be used, which does not have HTMLImageElement defined. +// +// we will get the following errors complaining that HTMLImageElement is not defined: +// +// ==================================================================================================================== +// +// ../common/dist/cjs/tensor-factory.d.ts:187:29 - error TS2552: Cannot find name 'HTMLImageElement'. Did you mean +// 'HTMLLIElement'? +// +// 187 fromImage(imageElement: HTMLImageElement, options?: TensorFromImageElementOptions): +// Promise | TypedTensor<'uint8'>>; +// ~~~~~~~~~~~~~~~~ +// +// node_modules/@webgpu/types/dist/index.d.ts:83:7 - error TS2552: Cannot find name 'HTMLImageElement'. Did you mean +// 'HTMLLIElement'? +// +// 83 | HTMLImageElement +// ~~~~~~~~~~~~~~~~ +// +// ==================================================================================================================== +// +// `HTMLImageElement` is only used in type declaration and not in real code. So we define it as `unknown` here to +// bypass the type check. +// +declare global { + type HTMLImageElement = unknown; +} + +import {OrtWasmMessage, SerializableTensorMetadata} from '../proxy-messages'; import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; @@ -11,7 +43,7 @@ self.onmessage = (ev: MessageEvent): void => { switch (ev.data.type) { case 'init-wasm': try { - initializeWebAssembly(ev.data.in) + initializeWebAssembly(ev.data.in!) .then( () => postMessage({type: 'init-wasm'} as OrtWasmMessage), err => postMessage({type: 'init-wasm', err} as OrtWasmMessage)); @@ -21,10 +53,10 @@ self.onmessage = (ev: MessageEvent): void => { break; case 'init-ort': try { - initRuntime(ev.data.in).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({ - type: 'init-ort', - err - } as OrtWasmMessage)); + initRuntime(ev.data.in!).then(() => postMessage({type: 'init-ort'} as OrtWasmMessage), err => postMessage({ + type: 'init-ort', + err + } as OrtWasmMessage)); } catch (err) { postMessage({type: 'init-ort', err} as OrtWasmMessage); } @@ -58,8 +90,7 @@ self.onmessage = (ev: MessageEvent): void => { break; case 'release': try { - const handler = ev.data.in!; - releaseSession(handler); + releaseSession(ev.data.in!); postMessage({type: 'release'} as OrtWasmMessage); } catch (err) { postMessage({type: 'release', err} as OrtWasmMessage); @@ -68,10 +99,16 @@ self.onmessage = (ev: MessageEvent): void => { case 'run': try { const {sessionId, inputIndices, inputs, outputIndices, options} = ev.data.in!; - run(sessionId, inputIndices, inputs, outputIndices, options) + run(sessionId, inputIndices, inputs, outputIndices, new Array(outputIndices.length).fill(null), options) .then( outputs => { - postMessage({type: 'run', out: outputs} as OrtWasmMessage, extractTransferableBuffers(outputs)); + if (outputs.some(o => o[3] !== 'cpu')) { + postMessage({type: 'run', err: 'Proxy does not support non-cpu tensor location.'}); + } else { + postMessage( + {type: 'run', out: outputs} as OrtWasmMessage, + extractTransferableBuffers(outputs as SerializableTensorMetadata[])); + } }, err => { postMessage({type: 'run', err} as OrtWasmMessage); diff --git a/js/web/lib/wasm/proxy-worker/tsconfig.json b/js/web/lib/wasm/proxy-worker/tsconfig.json new file mode 100644 index 0000000000000..ec1044612a569 --- /dev/null +++ b/js/web/lib/wasm/proxy-worker/tsconfig.json @@ -0,0 +1,8 @@ +{ + "extends": "../../../tsconfig.json", + "compilerOptions": { + "lib": ["WebWorker"] + }, + "include": ["main.ts", "../../build-def.d.ts"], + "exclude": [] +} diff --git a/js/web/package.json b/js/web/package.json index 898b0ba3ec5f1..9b4531d7766fe 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -24,7 +24,7 @@ "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", - "prebuild": "tsc -p . --noEmit", + "prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit", "build": "node ./script/build", "test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli", "prepack": "node ./script/build && node ./script/prepack" diff --git a/js/web/script/generate-webgpu-operator-md.ts b/js/web/script/generate-webgpu-operator-md.ts index 7408f17004f5e..eab8175a941bd 100644 --- a/js/web/script/generate-webgpu-operator-md.ts +++ b/js/web/script/generate-webgpu-operator-md.ts @@ -16,6 +16,8 @@ const COMMENTS: Record = { 'Reshape': 'no GPU kernel', 'Shape': 'no GPU kernel; an ORT warning is generated - need to fix', 'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling', + 'Attention': 'need implementing mask and past/present', + 'MultiHeadAttention': 'need implementing mask and past/present', }; /* eslint-disable max-len */ diff --git a/js/web/script/tsconfig.json b/js/web/script/tsconfig.json index d7dc4cbe2d291..23b1fc96eb558 100644 --- a/js/web/script/tsconfig.json +++ b/js/web/script/tsconfig.json @@ -1,7 +1,6 @@ { "extends": "../../tsconfig.tools.json", "compilerOptions": { - "sourceMap": true, - "noUnusedParameters": false + "sourceMap": true } } diff --git a/js/web/test/data/ops/attention.jsonc b/js/web/test/data/ops/attention.jsonc new file mode 100644 index 0000000000000..bd4483027cc25 --- /dev/null +++ b/js/web/test/data/ops/attention.jsonc @@ -0,0 +1,557 @@ +[ + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [4, 3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [213, 213], + "dims": [1, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic Batch 2 with 2 heads", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16 + ], + "dims": [2, 2, 8], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 + ], + "dims": [8, 6], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [320, 321, 320, 321, 320, 321, 320, 321], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863], + "dims": [1, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1.328187108039856, -1.297916054725647, -0.8599594831466675], + "dims": [1, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic one head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [2, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 2 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643], + "dims": [2, 6], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989, -0.989, 1.1103, -1.6898], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.8701779842376709, -2.6158859729766846, 0.8710794448852539, -2.5763747692108154, 0.9005484580993652, + -2.182751178741455, 2.1661579608917236, -2.1045265197753906, 1.6716957092285156, -1.797281265258789, + 1.7134947776794434, -1.765358328819275 + ], + "dims": [2, 3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987 + ], + "dims": [2, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [ + 1.1103, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, -1.6898, -0.989, -1.9029953479766846, 0.8710794448852539, + -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, 1.7134947776794434 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.6956915855407715, -2.8863370418548584, 1.3899128437042236, 1.6789076328277588, -1.4083852767944336, + -1.7009180784225464, -3.1053788661956787, 3.5959298610687256, 1.1027096509933472, -0.009643087163567543, + -1.694351315498352, -2.9284396171569824, 1.734721302986145, 2.0606398582458496, -0.2571452260017395, + 3.671973943710327, -5.285338401794434, -6.833454132080078, 1.7506506443023682, -2.262148380279541, + 2.5110034942626953, 1.440049171447754, -0.9423203468322754, 1.7506506443023682, -1.86212158203125, + -0.5036701560020447, -5.732386589050293, -1.5674757957458496, 1.7506510019302368, -2.264472246170044 + ], + "dims": [2, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846 + ], + "dims": [1, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [1, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [3, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + 3.7965505123138428, -2.3799397945404053, -3.9530906677246094, 0.5844926834106445, -2.9756431579589844, + 2.448162794113159, 4.34546422958374, 1.9380426406860352, 0.5870105624198914, -2.7368364334106445, + -0.4769568145275116, 4.255186557769775, -3.9529950618743896, 0.6987408995628357, -2.9756433963775635 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.01101303100586, -5.782258987426758, 6.016238689422607, 0.26747000217437744, -6.992541313171387, + -8.011263847351074, -5.782248020172119, 5.366001129150391, 0.26747000217437744, -6.99449348449707, + -8.011263847351074, -5.782265663146973, 6.016238689422607, 0.26747000217437744, -6.992537021636963, + -6.102723598480225, -7.28973388671875, -4.578637599945068, 7.2203369140625, -6.028444766998291, + -6.102705478668213, -7.2897748947143555, -3.7882626056671143, 5.393260478973389, -5.754333972930908, + -1.3616288900375366, -7.289827823638916, -6.341128349304199, 6.329389572143555, -5.751791954040527, + -2.3945987224578857, -14.532954216003418, 3.969801902770996, 12.744998931884766, -11.1966552734375, + -2.4002532958984375, -14.538958549499512, -6.684961318969727, 12.476543426513672, -9.24352741241455, + -4.787771701812744, -8.640848159790039, 3.969801902770996, -0.6471102833747864, -11.1966552734375 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 1 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + 1.3541864156723022, -7.813620090484619, -6.758509635925293, 7.597365856170654, -13.926229476928711, + -1.322464108467102, -7.297357559204102, -0.05962071940302849, 6.347561836242676, -5.869992256164551, + -1.3616288900375366, -7.28973388671875, 0.0386197566986084, 6.329389572143555, -5.751791954040527, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + 1.021930456161499, -2.373898983001709, 3.8501391410827637, -0.6108309626579285, -9.256340980529785 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/multi-head-attention.jsonc b/js/web/test/data/ops/multi-head-attention.jsonc new file mode 100644 index 0000000000000..05687bd482e24 --- /dev/null +++ b/js/web/test/data/ops/multi-head-attention.jsonc @@ -0,0 +1,194 @@ +[ + { + "name": "MultiHeadAttention Basic, one head", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.973228454589844, 5.973228454589844, 6.973228454589844, 7.973228454589844, 4.999990940093994, + 5.999990940093994, 6.999990940093994, 7.999990940093994 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.571832656860352, 5.571832656860352, 6.971858501434326, 7.971858501434326, 4.998325824737549, + 5.998325824737549, 6.999900817871094, 7.999900817871094 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic with bias", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4], + "dims": [12], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 5.943336009979248, 7.94333553314209, 9.999799728393555, 11.999798774719238, 5.9997992515563965, + 7.9997992515563965, 10, 11.999999046325684 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 8.99963665008545, 9.99963665008545, 10.99963665008545, 11.999635696411133, 13, 14, 15, 16, 9, 10, 11, 12, + 13, 14, 15, 16 + ], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[1]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 1, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 1, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/slice.jsonc b/js/web/test/data/ops/slice.jsonc index 9c90817a80c36..beef154a29932 100644 --- a/js/web/test/data/ops/slice.jsonc +++ b/js/web/test/data/ops/slice.jsonc @@ -21,6 +21,29 @@ } ] }, + { + "name": "Slice float32 with input[0] dim > 4", + "operator": "Slice", + "attributes": [], + "cases": [ + { + "name": "T[1, 1, 1, 1, 5] T[1] T[1] T[1] (float32)", + "inputs": [ + { + "data": [ + 0.3964604139328003, -0.8916832804679871, -1.6578896045684814, 1.960708737373352, 1.181204915046692 + ], + "dims": [1, 1, 1, 1, 5], + "type": "float32" + }, + { "data": [3], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" } + ], + "outputs": [{ "data": [1.960708737373352], "dims": [1, 1, 1, 1, 1], "type": "float32" }] + } + ] + }, { "name": "Slice int32", "operator": "Slice", diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index c80f0b04a9abc..37aa9394c7f96 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1336,6 +1336,7 @@ "add_int32.jsonc", //"and.jsonc", "asin.jsonc", + "attention.jsonc", "bias-add.jsonc", "bias-split-gelu.jsonc", "ceil.jsonc", @@ -1362,6 +1363,7 @@ "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", + "multi-head-attention.jsonc", //"neg.jsonc", "neg-int32.jsonc", "not.jsonc", diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 0fddddf58181c..8c186b9b36451 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -14,7 +14,7 @@ import {conv2d} from './test-conv-utils'; function createRandomArray(size: number): Float32Array { const randomTable = [0, 3, 6, 9, 2, 5, 8, 1, 4, 7]; return new Float32Array( - Array.from({length: size}, (v, k) => randomTable[k % 10] * 0.1 + randomTable[Math.trunc(k / 10) % 10] * 0.01)); + Array.from({length: size}, (_v, k) => randomTable[k % 10] * 0.1 + randomTable[Math.trunc(k / 10) % 10] * 0.01)); } interface TestData { inputShape: number[]; diff --git a/js/web/test/unittests/backends/webgl/test-pack-unpack.ts b/js/web/test/unittests/backends/webgl/test-pack-unpack.ts index 0b70144733227..61c21d4b689fb 100644 --- a/js/web/test/unittests/backends/webgl/test-pack-unpack.ts +++ b/js/web/test/unittests/backends/webgl/test-pack-unpack.ts @@ -291,7 +291,7 @@ describe('#UnitTest# - unpack - Tensor unpack', () => { webglInferenceHandler.session.textureManager.glContext.checkError(); const webglTexture = createTextureFromArray( webglInferenceHandler.session.textureManager.glContext, testData.rawData ? testData.rawData : inputData, - gl.RGBA, inputTextureShape[0], inputTextureShape[1]); + inputTextureShape[0], inputTextureShape[1]); webglInferenceHandler.session.textureManager.glContext.checkError(); const packedShape = inputTextureShape; const textureData = { diff --git a/js/web/test/unittests/backends/webgl/test-utils.ts b/js/web/test/unittests/backends/webgl/test-utils.ts index acb3f0002ce2f..092d63cd2ade4 100644 --- a/js/web/test/unittests/backends/webgl/test-utils.ts +++ b/js/web/test/unittests/backends/webgl/test-utils.ts @@ -4,7 +4,7 @@ import {WebGLContext} from '../../../../lib/onnxjs/backends/webgl/webgl-context'; export function createAscendingArray(size: number): Float32Array { - return new Float32Array(Array.from({length: size}, (v, i) => (i + 1))); + return new Float32Array(Array.from({length: size}, (_v, i) => (i + 1))); } // Returns an array by injecting 3 zeros after every element in the input array to be used for creating unpacked @@ -19,7 +19,7 @@ export function generateArrayForUnpackedTexture(input: Float32Array): Float32Arr // create a webgl texture and fill it with the array content export function createTextureFromArray( - glContext: WebGLContext, dataArray: Float32Array, type: GLenum, width: number, height: number): WebGLTexture { + glContext: WebGLContext, dataArray: Float32Array, width: number, height: number): WebGLTexture { const gl = glContext.gl; // create the texture diff --git a/js/web/tsconfig.json b/js/web/tsconfig.json index 0f35367db1d5a..d60d746e9328d 100644 --- a/js/web/tsconfig.json +++ b/js/web/tsconfig.json @@ -3,7 +3,6 @@ "compilerOptions": { "downlevelIteration": true, "declaration": true, - "noUnusedParameters": false, "typeRoots": ["./node_modules/@webgpu/types", "./node_modules/@types", "../node_modules/@types"] }, "include": ["lib", "test"], diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 0ed7d887fc5e5..57219c50f39aa 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -61,7 +61,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import OrtValue # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import SparseTensor # noqa: F401 -from onnxruntime.capi.training import * # noqa: F403 # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end try: # noqa: SIM105 diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index b693b58c7c40a..a7f83469a768d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -96,9 +96,9 @@ struct GroupQueryAttentionParameters { int kv_num_heads; int num_splits; // number of splits for splitkv bool is_unidirectional; // causal + int local_window_size; bool kv_share_buffer; - bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor - bool left_padding; // copies last token to last index if true + bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 4a266af789250..47f462d75fcc4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -63,6 +63,16 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int head_size = parameters.head_size; const int position_ids_format = parameters.position_ids_format; const int half_head_size = head_size / 2; + // Default input tensor shape is [batch, seq_len, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (parameters.transposed) { + // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -76,11 +86,10 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int s = static_cast((ptr / num_heads) % sequence_length); const int n = static_cast(ptr % num_heads); - const int block_offset = b * sequence_length * num_heads + s * num_heads + n; - const int data_offset = block_offset * head_size; + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - const T* input_data = input_src + data_offset; - T* output_data = output_dest + data_offset; + const T* input_data = input_src + block_offset; + T* output_data = output_dest + block_offset; // Cache is (M, H/2) const int position_id = (position_ids_format == 0) diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index cf8080800e072..7b2e8289f7b06 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -18,6 +18,7 @@ struct RotaryParameters { int num_heads; // num_heads = hidden_size / head_size int max_sequence_length; // Sequence length used by cos/sin cache int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; template @@ -33,8 +34,8 @@ Status CheckInputs(const T* input, // Check input const auto& input_dims = input->Shape().GetDims(); - if (input_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ", + if (input_dims.size() != 3 && input_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ", input_dims.size()); } // Check position_ids @@ -63,6 +64,14 @@ Status CheckInputs(const T* input, int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); int hidden_size = static_cast(input_dims[2]); + + bool transposed = false; + if (input_dims.size() == 4) { + // input is [batch, num_heads, seq, head_size] + sequence_length = static_cast(input_dims[2]); + hidden_size = static_cast(input_dims[1]) * static_cast(input_dims[3]); + transposed = true; + } int max_sequence_length = static_cast(cos_cache_dims[0]); int head_size = static_cast(cos_cache_dims[1]) * 2; int num_heads = hidden_size / head_size; @@ -111,6 +120,7 @@ Status CheckInputs(const T* input, output_parameters->num_heads = num_heads; output_parameters->max_sequence_length = max_sequence_length; output_parameters->position_ids_format = position_ids_format; + output_parameters->transposed = transposed; } return Status::OK(); @@ -118,4 +128,4 @@ Status CheckInputs(const T* input, } // namespace rotary_embedding_helper } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index c72d811170a27..320a05bb97dac 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -1,35 +1,38 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas.h" +#include "core/mlas/inc/mlas_qnbit.h" +#include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" -#include "core/mlas/inc/mlas_q4.h" namespace onnxruntime { namespace contrib { class MatMulNBits final : public OpKernel { public: - MatMulNBits(const OpKernelInfo& info) : OpKernel(info) { - ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + MatMulNBits(const OpKernelInfo& info) + : OpKernel(info), + K_{narrow(info.GetAttr("K"))}, + N_{narrow(info.GetAttr("N"))}, + block_size_{narrow(info.GetAttr("block_size"))}, + nbits_{narrow(info.GetAttr("bits"))} { ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op," - " additional bits support is planned."); + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); } Status Compute(OpKernelContext* context) const override; private: - int64_t K_; - int64_t N_; - int64_t block_size_; - int64_t nbits_; - bool column_wise_quant_{true}; + const size_t K_; + const size_t N_; + const size_t block_size_; + const size_t nbits_; + const bool column_wise_quant_{true}; }; Status MatMulNBits::Compute(OpKernelContext* ctx) const { @@ -45,11 +48,60 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + TensorShape b_shape({static_cast(N_), static_cast(K_)}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) + return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t batch_count = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + + if (MlasIsSQNBitGemmAvailable(nbits_, block_size_)) { + // number of bytes or elements between adjacent matrices + size_t b_data_matrix_stride_in_bytes, b_scale_matrix_stride, b_zero_point_matrix_stride_in_bytes; + MlasBlockwiseQuantizedBufferSizes(static_cast(nbits_), static_cast(block_size_), /* columnwise */ true, + static_cast(K), static_cast(N), + b_data_matrix_stride_in_bytes, b_scale_matrix_stride, + &b_zero_point_matrix_stride_in_bytes); + + const size_t b_matrix_size = K * N; + + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + const size_t b_matrix_offset = helper.RightOffsets()[i] / b_matrix_size; + + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].QuantBData = b_data + b_matrix_offset * b_data_matrix_stride_in_bytes; + data[i].QuantBScale = scales_data + b_matrix_offset * b_scale_matrix_stride; + data[i].QuantBZeroPoint = zero_points_data != nullptr + ? zero_points_data + b_matrix_offset * b_zero_point_matrix_stride_in_bytes + : nullptr; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + } + + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, data.data(), thread_pool); + + return Status::OK(); + } + + const size_t ldb = helper.Ldb(true); + AllocatorPtr allocator; - auto status = ctx->GetTempSpaceAllocator(&allocator); - ORT_RETURN_IF_ERROR(status); + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - // dequantize b, only 4b quantization is supported for now MlasDequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output @@ -67,29 +119,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); #endif - TensorShape b_shape({N_, K_}); - - MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); - - Tensor* y = ctx->Output(0, helper.OutputShape()); - - // Bail out early if the output is going to be empty - if (y->Shape().Size() == 0) - return Status::OK(); - - auto* y_data = y->MutableData(); - - const size_t max_len = helper.OutputOffsets().size(); - const size_t M = static_cast(helper.M()); - const size_t N = static_cast(helper.N()); - const size_t K = static_cast(helper.K()); - const size_t lda = helper.Lda(false); - const size_t ldb = helper.Ldb(true); - - // TODO: implement with native kernel - std::vector data(max_len); - for (size_t i = 0; i < max_len; i++) { + std::vector data(batch_count); + for (size_t i = 0; i < batch_count; i++) { data[i].BIsPacked = false; data[i].A = a_data + helper.LeftOffsets()[i]; data[i].lda = lda; @@ -101,7 +132,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { data[i].beta = 0.0f; } MlasGemmBatch(CblasNoTrans, CblasTrans, - M, N, K, data.data(), max_len, thread_pool); + M, N, K, data.data(), batch_count, thread_pool); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 89e2351428d40..cbe536c6ce45a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -69,6 +69,7 @@ struct Flash_fwd_params : public Qkv_params { int seqlen_q_rounded = 0; int seqlen_k_rounded = 0; int d_rounded = 0; + int rotary_dim = 0; // The scaling factors for the kernel. float scale_softmax = 0.0; @@ -92,12 +93,26 @@ struct Flash_fwd_params : public Qkv_params { index_t knew_head_stride = 0; index_t vnew_head_stride = 0; + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr = nullptr; + void* __restrict__ rotary_sin_ptr = nullptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx = nullptr; + + // Local window size + int window_size_left = -1; + int window_size_right = -1; + bool is_bf16 = false; bool is_causal = false; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. bool is_seqlens_k_cumulative = true; + + bool is_rotary_interleaved = false; + int num_splits = 0; // For split-KV version const cudaDeviceProp* dprops = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 89a27c4d2b0d3..76190aad68fdb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -35,7 +35,9 @@ void set_params_fprop(Flash_fwd_params& params, void* softmax_lse_d, float softmax_scale, bool is_causal, - bool kv_bsnh = true) { + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { // Set the pointers and strides. params.q_ptr = q; params.k_ptr = k; @@ -102,7 +104,21 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; + // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates + // local and causal, meaning when we have local window size params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.is_seqlens_k_cumulative = true; } @@ -227,7 +243,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - bool kv_bsnh) { + bool kv_bsnh, + int local_window_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -247,7 +264,9 @@ Status mha_fwd(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, - kv_bsnh); + kv_bsnh, + local_window_size, + is_causal ? 0 : -1); params.dprops = &dprops; params.knew_ptr = nullptr; params.vnew_ptr = nullptr; @@ -306,7 +325,10 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, nullptr, softmax_lse, softmax_scale, - is_causal); + is_causal, + true, + -1, + is_causal ? 0 : -1); params.dprops = &dprops; params.num_splits = 0; params.softmax_lseaccum_ptr = nullptr; @@ -347,11 +369,11 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, bool past_bsnh, // otherwise bnsh int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads - void* out_accum // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded -) { - if (seqlen_q == 1) { - is_causal = false; - } // causal=true is the same as causal=false in this case + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size) { + // if (seqlen_q == 1) { + // is_causal = false; + // } // causal=true is the same as causal=false in this case auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); @@ -372,7 +394,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, - past_bsnh); + past_bsnh, + local_window_size, + is_causal ? 0 : -1); params.dprops = &dprops; if (k != nullptr && v != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 58f4304251872..efc1f565c4fa0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -54,7 +54,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - bool kv_bsnh = true); + bool kv_bsnh = true, + int local_window_size = -1); Status mha_varlen_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, @@ -96,8 +97,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, bool past_bsnh, // otherwise bnsh int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads - void* out_accum = nullptr // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded -); + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size = -1); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index eb1c794d6df54..028233f66850f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -29,47 +29,6 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTE_HOST_DEVICE auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - make_layout(cute::size<2>(TileShape_MNK{}))); - - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - // TODO: Shouldn't this be size<1>? - make_layout(cute::size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, Tensor2& acc_o, float softmax_scale_log2) { @@ -123,7 +82,7 @@ inline __device__ void write_softmax_to_gmem( //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -144,12 +103,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); // We exit early and write 0 to gO and gLSE. // Otherwise we might read OOB elements from gK and gV. - if (n_block_max <= 0) { + if (n_block_max <= n_block_min) { const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), @@ -197,7 +158,6 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), cute::Shape, cute::Int>{}, make_stride(params.q_row_stride, _1{})); @@ -332,9 +292,9 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -364,22 +324,22 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -390,8 +350,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // Convert scores from fp32 to fp16/bf16 cute::Tensor rP = flash::convert_type(scores); @@ -408,14 +368,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { + if (n_masking_steps > 1 && n_block <= n_block_min) { --n_block; break; } } // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); @@ -431,7 +391,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -441,8 +401,15 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi } // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); cute::Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -543,7 +510,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -572,11 +539,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons if (m_block * kBlockM >= binfo.actual_seqlen_q) return; const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; - const int n_block_min = n_split_idx * n_blocks_per_split; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); - if (Is_causal) { + if (Is_causal || Is_local) { n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 // We exit early and write 0 to gOaccum and -inf to gLSEaccum. @@ -626,10 +595,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -641,16 +609,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, make_stride(params.v_row_stride, _1{})); - // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, - // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. - // This maps to accessing the first 64 rows of knew_ptr. - Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), - Shape, Int>{}, - make_stride(params.knew_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } - Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), - Shape, Int>{}, - make_stride(params.vnew_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQ{}); @@ -664,11 +622,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); typename Kernel_traits::TiledMma tiled_mma; @@ -732,17 +688,129 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons } // Prologue + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + } + } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + if (n_block_max > n_block_copy_min) { + tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; + tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; + } + } + // Read Q from gmem to smem, optionally apply rotary embedding. Tensor tQrQ = make_fragment_like(tQgQ); - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } + } int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // flash::cp_async_wait<0>(); @@ -760,9 +828,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -770,32 +838,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (cute::thread0()) { print(tKgK); } - // if (cute::thread0()) { print(tKsK); } - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - // __syncthreads(); - // if (cute::thread0()) { print(tKgK); } - // __syncthreads(); - } - // Advance gV if (masking_step > 0) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - if (Append_KV) { - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); } cute::cp_async_fence(); @@ -810,15 +860,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); @@ -826,26 +876,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } // __syncthreads(); - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); } - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } - if (n_block > n_block_min) { // Advance gK - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); } tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - if (Append_KV) { - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - } - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, - binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -853,8 +887,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } // Convert scores from fp32 to fp16/bf16 @@ -879,20 +913,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } // Advance gV tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - if (Append_KV) { - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( @@ -901,22 +924,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - if (Append_KV) { - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, - binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -924,7 +935,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -1031,7 +1049,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1047,12 +1065,12 @@ inline __device__ void compute_attn(const Params& params) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1061,24 +1079,23 @@ inline __device__ void compute_attn_splitkv(const Params& params) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void combine_attn_seqk_parallel(const Params& params) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; constexpr int kMaxSplits = 1 << Log_max_splits; - constexpr int kBlockM = 16; constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer"); - static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32"); - static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); // Shared memory. // kBlockM + 1 instead of kBlockM to reduce bank conflicts. @@ -1094,10 +1111,10 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { make_stride(params.b * params.h * params.seqlen_q, _1{})); Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); - constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads; + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // Read the LSE values from gmem and store them in shared memory, then tranpose them. - constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM; + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; #pragma unroll for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; @@ -1165,7 +1182,12 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), Shape, Int>{}, Stride, _1>{}); - typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); Tensor tOrO = make_tensor(shape(tOgOaccum)); @@ -1183,8 +1205,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } } -// Load Oaccum in then scale and accumulate to O -#pragma unroll 2 + // Load Oaccum in then scale and accumulate to O for (int split = 0; split < params.num_splits; ++split) { flash::copy( gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index 82dfa59b8f8e7..87d189a803f8a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -10,29 +10,30 @@ namespace onnxruntime { namespace flash { -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn(params); + flash::compute_attn(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn_splitkv(params); + flash::compute_attn_splitkv(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static_assert(Log_max_splits >= 1); - flash::combine_attn_seqk_parallel(params); + flash::combine_attn_seqk_parallel(params); #else (void)params; #endif @@ -52,20 +53,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { const bool is_even_K = params.d == Kernel_traits::kHeadDim; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - // ORT_ENFORCE(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); }); }); } @@ -82,40 +88,46 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(params.num_splits > 1, Split, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); - auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV > ; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); }); }); }); }); }); if (params.num_splits > 1) { - dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16); + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } }); } @@ -130,7 +142,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 32; + constexpr static int Headdim = 32; BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_causal>(params, stream); }); @@ -138,7 +150,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 64; + constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using block size (64 x 256) is 27% slower for seqlen=2k @@ -174,8 +186,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 128; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 128; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. @@ -201,8 +213,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 160; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 160; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, H100, 128 x 32 is the fastest. // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -241,12 +253,11 @@ void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 224; - constexpr size_t threshold = 2 * Headdim * (128 + 2 * 64); - size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + constexpr static int Headdim = 224; + int max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= threshold) { // 112 KB + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); @@ -262,16 +273,14 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 256; - constexpr size_t min_threshold = 2 * Headdim * (128 + 2 * 64); - constexpr size_t max_threshold = 4 * Headdim * (64 + 2 * 64); + constexpr static int Headdim = 256; size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= min_threshold && max_smem_per_sm < max_threshold) { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h index 134f159e258c4..1c0ed7f2fc2e8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -161,7 +161,14 @@ struct Flash_fwd_kernel_traits : public Base { cute::Stride<_16, _1>>>; using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, - cute::Layout>{})); // Val layout, 4 vals per store + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 842edf3a98a86..8017f83bbb01d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor& tensor, const int max_ } } -template -inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride) { +template +inline __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; @@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } @@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i } } +template +inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, + max_seqlen_q, warp_row_stride, -1, 0); +} + template inline __device__ void apply_mask_causal_w_idx( Tensor& tensor, Tensor const& idx_rowcol, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h index 02042e183f808..271112c5e890a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -307,7 +307,7 @@ template inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, int max_MN = 0) { + Tensor const& predicate_K, const int max_MN = 0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -334,65 +334,161 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor const //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor const& S0, - Tensor const& S1, +inline __device__ void copy_w_min_idx(Tensor const& S, Tensor& D, Tensor const& identity_MN, Tensor const& predicate_K, - const int max_MN = 0, const int row_idx_switch = 0) { - CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{}); + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K - // There's no case where !Clear_OOB_K && Clear_OOB_MN - static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); -// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); } -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); } + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } #pragma unroll - for (int m = 0; m < size<1>(S0); ++m) { - auto& S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1; - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } #pragma unroll - for (int k = 0; k < size<2>(S0); ++k) { + for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { - cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_interleaved(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { cute::clear(D(_, m, k)); } } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_w_min_idx(Tensor const& S, - Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, - const int max_MN = 0, const int min_MN = 0) { +inline __device__ void copy_rotary_contiguous(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); #pragma unroll for (int m = 0; m < size<1>(S); ++m) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } #pragma unroll for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || predicate_K(k)) { - cute::copy(S(_, m, k), D(_, m, k)); + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index f21dff08e0350..93892169f6c79 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -44,9 +44,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); num_heads_ = static_cast(num_heads); kv_num_heads_ = static_cast(kv_num_heads); - is_unidirectional_ = true; - // left_padding_ = info.GetAttrOrDefault("left_padding_last_token", 0) == 1; is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -92,8 +91,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { is_past_bsnh_, scale_, device_prop.maxThreadsPerBlock)); - parameters.is_unidirectional = is_unidirectional_; - // parameters.left_padding = left_padding_; + parameters.local_window_size = local_window_size_; int sequence_length = parameters.sequence_length; TensorShapeVector output_shape(3); @@ -139,6 +137,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { bool use_memory_efficient_attention = !use_flash_attention && !disable_memory_efficient_attention_ && + local_window_size_ == -1 && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -222,6 +221,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.k = reinterpret_cast(k_buffer.get()); data.v = reinterpret_cast(v_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index aade0436dc141..54a8127e29e7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -22,8 +22,7 @@ class GroupQueryAttention final : public CudaKernel { protected: int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention - // bool left_padding_; // shifts last token to end of buffer - bool is_unidirectional_; // causal + int local_window_size_; bool is_past_bsnh_; float scale_; bool disable_flash_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 2d158155eeba9..b22ccb68c1e7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -468,55 +468,6 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i return CUDA_CALL(cudaGetLastError()); } -// // Kernel to append new kv to kv buffer in place -// template -// __global__ void LeftPadLast(const int max_seqlen, -// T* kv_buff, -// const int* seqlens_k) { // refers to kv buff; otherwise bnsh -// const int h = threadIdx.x; -// const int n = blockIdx.x; -// const int b = blockIdx.y; - -// const int num_heads = gridDim.x; -// const int H = blockDim.x; - -// const int present_batch_stride = max_seqlen * num_heads * H; -// const int present_row_stride = num_heads * H; -// const int present_head_stride = H; - -// // kv_buff: BTNH or BNTH with buffered memory for new -// // new_kv: BLNH - -// const int s = seqlens_k[b]; - -// const int in_offset = b * present_batch_stride + s * present_row_stride + n * present_head_stride + h; -// const int out_offset = b * present_batch_stride + (max_seqlen - 1) * present_row_stride + n * present_head_stride + h; -// kv_buff[out_offset] = kv_buff[in_offset]; -// } - -// // Concat new to kv buffer in place -// template -// Status LaunchLeftPadLast(contrib::GroupQueryAttentionParameters& parameters, -// GroupQueryAttentionData& data, -// cudaStream_t stream, -// const int max_threads_per_block) { -// const int batch_size = parameters.batch_size; -// const int sequence_length = parameters.sequence_length; -// const int num_heads = parameters.num_heads; -// const int head_size = parameters.head_size; - -// // Indicates past sequence_length of each sequence -// const int* seqlens_k = reinterpret_cast(data.seqlens_k); - -// const int H = head_size / 4; -// const dim3 grid(num_heads, batch_size, 1); -// const dim3 block(H, 1, 1); -// LeftPadLast<<>>(sequence_length, -// reinterpret_cast(data.output), -// seqlens_k); -// return CUDA_CALL(cudaGetLastError()); -// } - ////////// Launch Kernels #if USE_FLASH_ATTENTION @@ -541,7 +492,7 @@ Status FlashAttention( void* key = reinterpret_cast(const_cast(data.key)); void* value = reinterpret_cast(const_cast(data.value)); - bool is_causal = parameters.is_unidirectional; + bool is_causal = true; // Note: seqlens_k is past sequence length for flash if (parameters.is_prompt) { @@ -579,7 +530,7 @@ Status FlashAttention( seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, present_sequence_length, kv_sequence_length, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); + reinterpret_cast(data.out_accum), parameters.local_window_size)); } else { // Not share buffer case // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient @@ -611,13 +562,9 @@ Status FlashAttention( seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, present_sequence_length, 0, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); + reinterpret_cast(data.out_accum), parameters.local_window_size)); } - // if (parameters.left_padding && parameters.is_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -704,9 +651,11 @@ Status EfficientAttention( p.max_sequence_length = present_sequence_length; p.qk_head_size = head_size; p.v_head_size = head_size; - p.causal = parameters.is_unidirectional; + p.causal = true; p.scale = scale; p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; p.query = query; p.key = key; p.value = value; @@ -721,10 +670,6 @@ Status EfficientAttention( p.has_custom_right_padding = true; run_memory_efficient_attention(p); - // if (parameters.left_padding && parameters.is_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index b4b5dac1fbe19..2d12e975d88d7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -74,7 +74,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { parameters.max_sequence_length, parameters.position_ids_format, interleaved, - device_prop.maxThreadsPerBlock); + device_prop.maxThreadsPerBlock, + parameters.transposed); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index c54e72dcfce13..e1b83bd8caf54 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -27,7 +27,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int num_heads, const int head_size, const int position_ids_format, - const bool interleaved) { + const bool interleaved, + const int batch_stride, + const int seq_stride, + const int head_stride) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently @@ -37,11 +40,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int i = threadIdx.x; - const int block_offset = b * sequence_length * num_heads + s * num_heads + n; - const int data_offset = block_offset * head_size; + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - const T* input_data = input + data_offset; - T* output_data = output + data_offset; + const T* input_data = input + block_offset; + T* output_data = output + block_offset; // Cache is (M, H/2) const int half_head_size = head_size / 2; @@ -83,7 +85,8 @@ Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block) { + const int max_threads_per_block, + const bool transposed) { constexpr int smem_size = 0; const dim3 grid(num_heads, sequence_length, batch_size); @@ -94,10 +97,22 @@ Status LaunchRotaryEmbeddingKernel( // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. + // Default input tensor shape is [batch, seq, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (transposed) { + // When transposed, input tensor shape is [batch, num_heads, seq, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } + assert(head_size <= max_threads_per_block); RotaryEmbeddingBSNH<<>>( output, input, cos_cache, sin_cache, position_ids, - sequence_length, num_heads, head_size, position_ids_format, interleaved + sequence_length, num_heads, head_size, position_ids_format, interleaved, + batch_stride, seq_stride, head_stride ); return CUDA_CALL(cudaGetLastError()); @@ -117,7 +132,8 @@ template Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); template Status LaunchRotaryEmbeddingKernel( cudaStream_t stream, @@ -133,7 +149,8 @@ template Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index 29ff48a8ad0fb..ee1ccc43dcbff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -24,7 +24,8 @@ Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 6162895d00d79..7172a28316f16 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -70,6 +70,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); @@ -260,6 +262,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h new file mode 100644 index 0000000000000..86136ea244e23 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include "cutlass/device_kernel.h" + +using namespace onnxruntime; + +namespace ort_fastertransformer { + +template +inline int compute_occupancy_for_kernel() { + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size > (48 << 10)) { + cudaError_t status = + cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (status == cudaError::cudaErrorInvalidValue) { + // Clear the error bit since we can ignore this. + // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an + // occupancy of 0. This will cause the heuristic to ignore this configuration. + status = cudaGetLastError(); + return 0; + } + CUDA_CALL_THROW(status); + } + + int max_active_blocks = -1; + CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::Kernel, + GemmKernel::kThreadCount, smem_size)); + + return max_active_blocks; +} + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc new file mode 100644 index 0000000000000..5d4c6793ec995 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cutlass_heuristic.h" + +#include +#include +#include + +namespace ort_fastertransformer { + +struct TileShape { + int m; + int n; +}; + +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { + switch (tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + default: + ORT_THROW("[FT Error][get_grid_shape_for_config] Invalid config"); + } +} + +bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape, + const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only) { + // All tile sizes have a k_tile of 64. + static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; + } + + if ((k % split_k_factor) != 0) { + return false; + } + + const int k_elements_per_split = static_cast(k / split_k_factor); + if ((k_elements_per_split % k_tile) != 0) { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) { + return false; + } + + return true; +} + +std::vector get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) { + std::vector simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + + std::vector square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64}; + + std::vector quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + + const std::vector allowed_configs = is_weight_only ? quant_B_configs : square_configs; + return simt_configs_only ? simt_configs : allowed_configs; +} + +std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) { + std::vector tiles = get_candidate_tiles(is_weight_only, simt_configs_only); + + std::vector candidate_configs; + const int min_stages = 2; + const int max_stages = sm >= 80 ? 4 : 2; + + for (const auto& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages}; + candidate_configs.push_back(config); + } + } + + return candidate_configs; +} + +CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, + const std::vector& occupancies, const int64_t m, + const int64_t n, const int64_t k, const int64_t, + const int split_k_limit, const size_t workspace_bytes, + const int multi_processor_count, const int is_weight_only) { + if (occupancies.size() != candidate_configs.size()) { + ORT_THROW( + "[FT Error][estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); + int occupancy = occupancies[ii]; + + if (occupancy == 0) { + continue; + } + + // Keep small tile sizes when possible. + if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile && + current_m_tile < tile_shape.m) { + continue; + } + + const int ctas_in_m_dim = static_cast((m + tile_shape.m - 1) / tile_shape.m); + const int ctas_in_n_dim = static_cast((n + tile_shape.n - 1) / tile_shape.n); + + for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { + if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) { + const int ctas_per_wave = occupancy * multi_processor_count; + const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave); + const float current_score = float(num_waves_total) - num_waves_fractional; + + const float score_slack = 0.1f; + if (current_score < config_score || + ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = + CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + current_m_tile = tile_shape.m; + } else if (current_score == config_score && + (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor || + current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = + CutlassGemmConfig{candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages}; + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) { + ORT_THROW("[FT Error] Heurisitc failed to find a valid config."); + } + + return best_config; +} + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h new file mode 100644 index 0000000000000..e70efe0503b55 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "ft_gemm_configs.h" + +#include +#include +#include + +#include "core/common/common.h" + +using namespace onnxruntime; + +namespace ort_fastertransformer { + +std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only); + +CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, + const std::vector& occupancies, const int64_t m, + const int64_t n, const int64_t k, const int64_t num_experts, + const int split_k_limit, const size_t workspace_bytes, + const int multi_processor_count, const int is_weight_only); + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h new file mode 100644 index 0000000000000..78d206bf1d9bc --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. + * + */ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" + +namespace cutlass { +namespace epilogue { +namespace thread { + +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) { +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +template <> +struct GELU_taylor { + static const bool kIsHeavy = true; + CUTLASS_DEVICE + float operator()(float const& z) const { + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); + + return float( + cutlass::constants::half() * z * + (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); } +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +namespace ort_fastertransformer { + +struct EpilogueOpBiasSilu {}; + +struct EpilogueOpBiasReLU {}; + +struct EpilogueOpBiasFtGelu {}; + +struct EpilogueOpBias {}; + +struct EpilogueOpNoBias {}; + +template +struct Epilogue {}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, + ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling, + cutlass::FloatRoundStyle::round_to_nearest, true>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +template +struct Epilogue { + using Op = + cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h new file mode 100644 index 0000000000000..a5faad423fad9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace ort_fastertransformer { +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. +enum class CutlassTileConfig { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape128x32x64 +}; + +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + // SPLIT_K_PARALLEL // Not supported yet +}; + +struct CutlassGemmConfig { + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h new file mode 100644 index 0000000000000..311ed323cb90c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h @@ -0,0 +1,79 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/matrix_coord.h" + +#include "moe_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct GemmMoeProblemVisitor + : public MoeProblemVisitor, ThreadblockShape, + GroupScheduleMode_, PrefetchTileCount, ThreadCount> { + static bool const kTransposed = Transposed; + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base = + MoeProblemVisitor; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, shared_storage_, block_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h new file mode 100644 index 0000000000000..eb33a98e4246f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct LayoutDetailsB {}; + +// Volta specialiations. Volta will dequantize before STS, so we need a different operator +template +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct MixedGemmArchTraits {}; + +template +struct MixedGemmArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ========================= Volta Traits =========================== +// Volta will always dequantize after the global memory load. +// This will instantiate any HMMA tensorcore kernels for Volta. +template +struct MixedGemmArchTraits< + TypeA, TypeB, cutlass::arch::Sm70, + typename cutlass::platform::enable_if::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Turing Traits ============================== +template +struct MixedGemmArchTraits< + TypeA, TypeB, cutlass::arch::Sm75, + typename cutlass::platform::enable_if::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits< + TypeA, TypeB, cutlass::arch::Sm80, + typename cutlass::platform::enable_if::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h new file mode 100644 index 0000000000000..bfe30b71170d8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h @@ -0,0 +1,463 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "gemm_moe_problem_visitor.h" +#include "tile_interleaved_layout.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. +// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. +template +using void_t = void; + +template +struct use_dq_gemm : platform::false_type {}; + +template +struct use_dq_gemm> : platform::true_type {}; + +// SFINAE overload for dequantizing gemm +template ::value, bool>::type = true> +CUTLASS_DEVICE static void run_mma(Mma mma, int gemm_k_iterations, typename Mma::FragmentC& accum, + typename Mma::IteratorA iterator_A, typename Mma::IteratorB iterator_B, + typename Mma::FragmentC const& src_accum, ElementScale* weight_scale_ptr, + MatrixCoord scale_extent, const int thread_idx, MatrixCoord tb_offset_scale) { + typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), weight_scale_ptr, + scale_extent, thread_idx, tb_offset_scale); + + mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_scale, src_accum); +} + +// SFINAE overload for normal gemm. This completely ignores the scale parameters +template ::value, bool>::type = true> +CUTLASS_DEVICE static void run_mma(Mma mma, int gemm_k_iterations, typename Mma::FragmentC& accum, + typename Mma::IteratorA iterator_A, typename Mma::IteratorB iterator_B, + typename Mma::FragmentC const& src_accum, ElementScale* weight_scale_ptr, + MatrixCoord scale_extent, const int thread_idx, MatrixCoord tb_offset_scale) { + mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeFCGemm { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = + kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = + GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t* total_rows_before_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + weight_scales(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + total_rows_before_expert(nullptr), + gemm_n(0), + gemm_k(0), + host_problem_sizes(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, int threadblock_count, typename EpilogueOutputOp::Params output_op, + const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C, + ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, + GemmCoord* host_problem_sizes = nullptr) + : problem_count(problem_count), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(const_cast(ptr_A)), + ptr_B(const_cast(ptr_B)), + weight_scales(const_cast(weight_scales)), + ptr_C(const_cast(ptr_C)), + ptr_D(ptr_D), + total_rows_before_expert(total_rows_before_expert), + gemm_n(gemm_n), + gemm_k(gemm_k), + host_problem_sizes(nullptr) { + if (platform::is_same::value || platform::is_same::value) { + assert(weight_scales); + } + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, + tile_count), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + weight_scales(args.weight_scales), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D) {} + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params(args.total_rows_before_expert, args.gemm_n, args.gemm_k, + args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } + + static Status can_implement(Arguments const& args) { + if (args.weight_scales != nullptr) { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + // The dummy template parameter is not used and exists so that we can compile this code using + // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in + // a namespace + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) { CUTLASS_NOT_IMPLEMENTED(); } + }; + + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert(platform::is_same::value && kInterleave == 1 || + platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // + // Problem visitor. + // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset(int(cta_idx / grid_shape.n()) * Mma::Shape::kM, + int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump = + problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value ? gemm_n : gemm_k * kInterleave; + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); + run_mma(mma, gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, weight_scale_ptr, + {1, problem_size.n()}, thread_idx, tb_offset_scale); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; + ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h new file mode 100644 index 0000000000000..60608f462fde5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "ft_gemm_configs.h" + +namespace ort_fastertransformer { + +enum class ActivationType { Gelu, + Relu, + Silu, + GeGLU, + ReGLU, + SiGLU, + Identity, + InvalidType }; + +template +class MoeGemmRunner { + public: + MoeGemmRunner(); + + void initialize(int sm); + + void moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, ActivationType activation_type, cudaStream_t stream); + + void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream); + + private: + template + void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr); + + template + void run_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cudaStream_t stream); + + private: + int sm_; + int multi_processor_count_; +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu new file mode 100644 index 0000000000000..1d9a249db4237 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_kernels_template.h" + +namespace ort_fastertransformer { +template class MoeGemmRunner; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu new file mode 100644 index 0000000000000..7b250e6ca9060 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_kernels_template.h" + +namespace ort_fastertransformer { +template class MoeGemmRunner; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h new file mode 100644 index 0000000000000..66950c9b65970 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -0,0 +1,428 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" + +#include "compute_occupancy.h" +#include "epilogue_helpers.h" +#include "layout_traits_helper.h" +#include "moe_cutlass_kernel.h" + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +#include "cutlass_heuristic.h" +#include "moe_gemm_kernels.h" + +#include +#include +#include +#include + +namespace ort_fastertransformer { + +// ============================= Variable batched Gemm things =========================== +template +void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, const int multi_processor_count, + cudaStream_t stream, int* kernel_occupancy = nullptr) { + if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { + ORT_THROW("[FT Error][MoeGemm] Grouped gemm does not support split-k"); + } + + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float"); + + static_assert(cutlass::platform::is_same::value || + cutlass::platform::is_same::value || + cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + using ElementType = ElementType_; + + using CutlassWeightType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, + WeightType>::type; + using CutlassWeightType = CutlassWeightType_; + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = + typename Epilogue::Op; + + // Finally, set up the kernel. + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< + ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessA, + CutlassWeightType, typename MixedGemmArchTraits::LayoutB, cutlass::ComplexTransform::kNone, + MixedGemmArchTraits::ElementsPerAccessB, ElementType, cutlass::layout::RowMajor, ElementAccumulator, + typename MixedGemmArchTraits::OperatorClass, arch, ThreadblockShape, WarpShape, + typename MixedGemmArchTraits::InstructionShape, EpilogueOp, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) { + *kernel_occupancy = compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + if (occupancy == 0) { + ORT_THROW("[FT Error][MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM kernel"); + } + const int threadblock_count = multi_processor_count * occupancy; + + typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); + + typename GemmGrouped::Arguments args( + num_experts, threadblock_count, epilogue_op, reinterpret_cast(A), + reinterpret_cast(B), reinterpret_cast(weight_scales), + reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, gemm_n, + gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = + "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass variable batched gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } +} + +template +struct dispatch_stages { + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); + } +}; + +template +struct dispatch_stages { + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + } +}; + +template +struct dispatch_stages 2)>::type> { + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + } +}; + +template +void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.stages) { + case 2: + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case 3: + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case 4: + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + ORT_THROW("[FT Error][MoE][dispatch_gemm_config] " + err_msg); + break; + } +} + +// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. +// This overload is only enabled when T == WeightType. +template < + typename T, typename WeightType, typename arch, typename EpilogueTag, + typename std::enable_if::value && std::is_same::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::Undefined: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; + default: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for same type MoE tensorop GEMM."); + break; + } +} + +// Tensorop GEMM overload +// Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve +// compile time +template < + typename T, typename WeightType, typename arch, typename EpilogueTag, + typename std::enable_if::value && !std::is_same::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::Undefined: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; + default: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for mixed type tensorop GEMM."); + break; + } +} + +// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. +template ::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, C, + total_rows_before_expert, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::Undefined: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW( + "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config should have already been set by heuristic."); + break; + default: + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config for float MoE gemm."); + break; + } +} + +template +MoeGemmRunner::MoeGemmRunner() {} + +template +void MoeGemmRunner::initialize(int sm_version) { + int device{-1}; + cudaGetDevice(&device); + sm_ = sm_version; + cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device); +} + +template +template +void MoeGemmRunner::dispatch_to_arch(const T* A, const WeightType* B, + const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, cudaStream_t stream, + int* occupancy) { + if (sm_ >= 70 && sm_ < 75) { + dispatch_moe_gemm_to_cutlass( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + sm_, multi_processor_count_, stream, occupancy); + } else if (sm_ >= 75 && sm_ < 80) { + dispatch_moe_gemm_to_cutlass( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + sm_, multi_processor_count_, stream, occupancy); + } else if (sm_ >= 80 && sm_ < 90) { + dispatch_moe_gemm_to_cutlass( + A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + sm_, multi_processor_count_, stream, occupancy); + } else { + ORT_THROW("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); + } +} + +template +template +void MoeGemmRunner::run_gemm(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream) { + static constexpr bool is_weight_only = !std::is_same::value; + static constexpr bool only_simt_configs = std::is_same::value; + std::vector candidate_configs = get_candidate_configs(sm_, is_weight_only, only_simt_configs); + std::vector occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + dispatch_to_arch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, candidate_configs[ii], stream, &occupancies[ii]); + } + + static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. + static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. + CutlassGemmConfig chosen_config = + estimate_best_config_from_occupancies(candidate_configs, occupancies, total_rows, gemm_n, gemm_k, num_experts, + split_k_limit, workspace_bytes, multi_processor_count_, is_weight_only); + + dispatch_to_arch(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, chosen_config, stream); +} + +template +void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, + const T* biases, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, ActivationType activation_type, + cudaStream_t stream) { + switch (activation_type) { + case ActivationType::Relu: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + break; + case ActivationType::Gelu: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); + break; + case ActivationType::Silu: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + break; + case ActivationType::Identity: + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + break; + case ActivationType::InvalidType: + ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + break; + default: { + ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + } + } +} + +template +void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, + int64_t gemm_k, int num_experts, cudaStream_t stream) { + run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); +} + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu new file mode 100644 index 0000000000000..398ce4ee9880f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -0,0 +1,830 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +#include "moe_kernel.h" + +#if CUDA_VERSION >= 11000 +#include +#include +#include +#else +#include "cub/cub.cuh" +#include "cub/device/device_radix_sort.cuh" +#include "cub/util_type.cuh" +#endif + +namespace ort_fastertransformer { + +static constexpr int WARP_SIZE = 32; + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void moe_softmax(const T* input, const bool* finished, T* output, const int num_cols) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = T(val); + } +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, const int) { + // Does not support pre-Kepler architectures + ; +} +#else +template +__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output, + int* indices, int* source_rows, int num_experts, int k) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool should_process_row = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? result_kvp.key : num_experts; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} +#endif + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + void topk_gating_softmax(const T* input, const bool* finished, T* output, int num_rows, int* indices, + int* source_rows, int k) { + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) return; + const bool should_process_row = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. + using AccessType = cutlass::AlignedArray; + + // Finally, we pull in the data from global mem + cutlass::Array row_chunk_input; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk_input); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + using ComputeType = float; + using Converter = cutlass::NumericArrayConverter; + Converter compute_type_converter; + cutlass::Array row_chunk = compute_type_converter(row_chunk_input); + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + ComputeType thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index.​ + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); + int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = T(max_val); + indices[idx] = should_process_row ? expert : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = ComputeType(-10000.f); + } + } + } +} + +namespace detail { +// Constructs some constants needed to partition the work across threads at compile time. +template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = std::max(1, (int)EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, + int num_rows, int num_experts, int k, cudaStream_t stream) { + static constexpr unsigned long MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topk_gating_softmax + <<>>(input, finished, output, num_rows, indices, source_row, k); +} + +template +void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output, + int* indices, int* source_row, int num_rows, int num_experts, + int k, cudaStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + + switch (num_experts) { + case 2: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 4: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 8: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 16: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 32: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 64: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 128: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + case 256: { + topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, + num_experts, k, stream); + break; + } + default: { + static constexpr int TPB = 256; + moe_softmax<<>>(input, finished, softmax_temp_output, num_experts); + moe_top_k + <<>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k); + } + } +} + +// ========================== CUB Sorting things ==================================== +CubKeyValueSorter::CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +CubKeyValueSorter::CubKeyValueSorter(int num_experts) + : num_experts_(num_experts), num_bits_((int)log2(num_experts) + 1) {} + +void CubKeyValueSorter::update_num_experts(int num_experts) { + num_experts_ = num_experts; + num_bits_ = (int)log2(num_experts) + 1; +} + +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs) { + num_key_value_pairs_ = num_key_value_pairs; + size_t required_storage = 0; + int* null_int = nullptr; + cub::DeviceRadixSort::SortPairs(NULL, required_storage, null_int, null_int, null_int, null_int, + (int)num_key_value_pairs, 0, num_bits_); + return required_storage; +} + +void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, + const int* values_in, int* values_out, const size_t num_key_value_pairs, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs); + size_t actual_ws_size = workspace_size; + + if (expected_ws_size > workspace_size) { + ORT_THROW("Error. The allocated workspace is too small to run this problem. Expected workspace size of at least ", + expected_ws_size, " but got problem size ", workspace_size, "\n"); + } + cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out, values_in, values_out, + (int)num_key_value_pairs, 0, num_bits_, stream); +} + +// ============================== Infer GEMM sizes ================================= +__device__ inline int find_total_elts_leq_target(const int* sorted_indices, const int arr_length, const int target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] > target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +// Sets up the gemm assuming the inputs, experts and outputs are stored in row major order. +// Assumes we want to perform output = matmul(inputs, experts) + bias +__global__ void compute_total_rows_before_expert_kernel(const int* sorted_experts, const int sorted_experts_len, + const int64_t num_experts, int64_t* total_rows_before_expert) { + // First, compute the global tid. We only need 1 thread per expert. + const int expert = blockIdx.x * blockDim.x + threadIdx.x; + if (expert >= num_experts) return; + + // This should construct the last index where each expert occurs. + total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); +} + +template +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version) { + moe_gemm_runner_.initialize(sm_version); +} + +template +size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows, const int hidden_size, + const int inter_size, int num_experts, + int k) { + const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size)); + const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); + const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); + const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); + int num_softmax_outs = 0; + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + num_softmax_outs = static_cast(pad_to_multiple_of_16(num_rows * num_experts)); + } + + // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them + // in Encoder or Decoder before invoking FfnLayer forward. + size_t total_ws_bytes = 3 * num_moe_inputs * sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += buf_size * sizeof(T); // permuted_data + total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ + total_ws_bytes += num_softmax_outs * sizeof(T); + const int bytes_for_fc1_result = interbuf_size * sizeof(T); + const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows))); + sorter_.update_num_experts(num_experts); + + int bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + int remaining_bytes = static_cast(pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result)); + bytes_for_intermediate_and_sorting += remaining_bytes; + } + + total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + return total_ws_bytes; +} + +template +void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, int num_rows, + const int hidden_size, const int inter_size, + int num_experts, int k) { + const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size)); + const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); + const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); + const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); + // const int num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts); + + source_rows_ = (int*)ws_ptr; + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + permuted_data_ = (T*)(permuted_experts_ + num_moe_inputs); + + total_rows_before_expert_ = (int64_t*)(permuted_data_ + buf_size); + + fc1_result_ = (T*)(total_rows_before_expert_ + padded_experts); + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + softmax_out_ = (T*)(fc1_result_ + interbuf_size); + } else { + softmax_out_ = nullptr; + } +} + +template +void CutlassMoeFCRunner::run_moe_fc( + const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, + const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, + int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales, + int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { + static constexpr bool scales_required = + std::is_same::value || std::is_same::value; + + if (scales_required) { + if (fc1_scales == nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for first matmul is a null pointer"); + } else if (fc2_scales == nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for second matmul is a null pointer"); + } + } else { + if (fc1_scales != nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC1"); + } else if (fc2_scales != nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC2"); + } + } + + configure_ws_ptrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, k); + topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, + source_rows_, num_rows, num_experts, k, stream); + + const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); + sorter_.run((void*)fc1_result_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_, + permuted_rows_, k * num_rows, stream); + + initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_, + expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size, k, + stream); + + const int expanded_active_expert_rows = k * active_rows; + compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts, + total_rows_before_expert_, stream); + + moe_gemm_runner_.moe_gemm_bias_act(permuted_data_, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_, + total_rows_before_expert_, expanded_active_expert_rows, inter_size, hidden_size, + num_experts, fc1_activation_type, stream); + + moe_gemm_runner_.moe_gemm(fc1_result_, fc2_expert_weights, fc2_scales, fc2_result, total_rows_before_expert_, + expanded_active_expert_rows, hidden_size, inter_size, num_experts, stream); +} + +template +void CutlassMoeFCRunner::run_moe_fc( + const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, + const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, + int k, char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, + int* expert_for_source_row, cudaStream_t stream) { + run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, + fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, k, workspace_ptr, + fc2_result, nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, + expert_for_source_row, stream); +} + +template +void CutlassMoeFCRunner::compute_total_rows_before_expert(const int* sorted_indices, + const int total_indices, + int num_experts, + int64_t* total_rows_before_expert, + cudaStream_t stream) { + const int threads = std::min(1024, num_experts); + const int blocks = (num_experts + threads - 1) / threads; + + compute_total_rows_before_expert_kernel<<>>(sorted_indices, total_indices, num_experts, + total_rows_before_expert); +} + +// ========================== Permutation things ======================================= + +// Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing. + +// "expanded_x_row" simply means that the number of values is num_rows x k. It is "expanded" since we will have to +// duplicate some rows in the input matrix to match the dimensions. Duplicates will always get routed to separate +// experts in the end. + +// Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0, +// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input +// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus +// of the expanded index. + +template +__global__ void initialize_moe_routing_kernel(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, + int active_rows, int cols) { + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the + // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + const int expanded_dest_row = blockIdx.x; + const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + if (threadIdx.x == 0) { + expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; + } + + if (blockIdx.x < active_rows) { + // Duplicate and permute rows + const int source_row = expanded_source_row % num_rows; + + const T* source_row_ptr = unpermuted_input + source_row * cols; + T* dest_row_ptr = permuted_output + expanded_dest_row * cols; + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + dest_row_ptr[tid] = source_row_ptr[tid]; + } + } +} + +template +void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, + int active_rows, int cols, int k, cudaStream_t stream) { + const int blocks = num_rows * k; + const int threads = std::min(cols, 1024); + initialize_moe_routing_kernel + <<>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, num_rows, k * active_rows, cols); +} + +// Final kernel to unpermute and scale +// This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection. +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 +template +__global__ void finalize_moe_routing_kernel(const T*, T*, const T*, const T*, const T*, const T*, const int*, + const int*, int, const int) { + // Does not support pre-Kepler architectures + ; +} +#else +template +__global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, + const T* skip_1, const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int cols, int k) { + const int original_row = blockIdx.x; + int num_rows = gridDim.x; + T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; + + const T* skip_1_row_ptr = nullptr; + if (RESIDUAL_NUM == 1) { + skip_1_row_ptr = skip_1 + original_row * cols; + } + const T* skip_2_row_ptr = nullptr; + if (RESIDUAL_NUM == 2) { + skip_2_row_ptr = skip_2 + original_row * cols; + } + + for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) { + T thread_output; + if (RESIDUAL_NUM == 0) { + thread_output = T(0); + } else if (RESIDUAL_NUM == 1) { + thread_output = skip_1_row_ptr[tid]; + } else if (RESIDUAL_NUM == 2) { + thread_output = skip_1_row_ptr[tid] + skip_2_row_ptr[tid]; + } + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int expanded_original_row = original_row + k_idx * num_rows; + const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + const int64_t k_offset = original_row * k + k_idx; + const T row_scale = scales[k_offset]; + const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; + + const int expert_idx = expert_for_source_row[k_offset]; + const T* bias_ptr = bias + expert_idx * cols; + + thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]); + } + reduced_row_ptr[tid] = thread_output; + } +} +#endif + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, + const T* scales, const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream) { + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, nullptr, nullptr, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); +} + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, + const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream) { + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + finalize_moe_routing_kernel + <<>>(expanded_permuted_rows, reduced_unpermuted_output, skip, nullptr, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); +} + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, + const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream) { + const int blocks = num_rows; + const int threads = std::min(cols, 1024); + if (skip_2 == nullptr) { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + } else { + finalize_moe_routing_kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, skip_1, skip_2, bias, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k); + } +} + +// ========================= TopK Softmax specializations =========================== +template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int, + int, int, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, + int, int, cudaStream_t); + +// ==================== Variable batched GEMM specializations ================================== +template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner; + +// ===================== Specializations for init routing ========================= +template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, + int, int, cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, + int, int, cudaStream_t); + +// ==================== Specializations for final routing =================================== +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, + const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const int*, const int*, + int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, + const int*, const int*, int, int, int, + cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const int*, + const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, + const float*, const int*, const int*, int, int, int, + cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, + const half*, const int*, const int*, int, int, int, + cudaStream_t); + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h new file mode 100644 index 0000000000000..5cefe4fa5dc47 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "moe_gemm_kernels.h" +#include + +#include "core/common/common.h" + +using namespace onnxruntime; + +namespace ort_fastertransformer { + +static inline size_t pad_to_multiple_of_16(size_t input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +/* + Launches the topk gating softmax required for the MoE layers. + + Params: + input - a [num_rows x num_experts] + finished - [num_rows] vector with 1 if the sentence at this row is done translating and 0 otherwise. + output - a buffer of shape [num_rows x k] containing the top-k values of the softmax for each row. + indices - a matrix of shape [num_rows x k] containing the top-k experts each row should get routed to. + source_rows - a matrix of shape [num_rows x k] used internally for permuting. source_rows[row][k] = k * num_rows + + row. It is constructed like this so we can track where each of the original rows end up in order to perform the + "k-way" reduction later in the routing. + + num_rows - The number of rows in the matrix + num_experts - The number of expert layers present + k - k value in topk +*/ +template +void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out, + int* indices, int* source_row, int num_rows, int num_experts, + int k, cudaStream_t stream); + +class CubKeyValueSorter { + public: + CubKeyValueSorter(); + + CubKeyValueSorter(int num_experts); + + void update_num_experts(int num_experts); + + size_t getWorkspaceSize(const size_t num_key_value_pairs); + + void run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, const int* values_in, + int* values_out, const size_t num_key_value_pairs, cudaStream_t stream); + + private: + size_t num_key_value_pairs_; + int num_experts_; + int num_bits_; +}; + +template +void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, int num_rows, + int active_rows, int cols, int k, cudaStream_t stream); + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, + const T* scales, const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream); + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, + const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream); + +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, + const T* skip_2, const T* bias, const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, int num_rows, int cols, + int k, cudaStream_t stream); + +// Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc . +// Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive. +// Avoid making several duplicates of this class. +template +class CutlassMoeFCRunner { + public: + CutlassMoeFCRunner(int sm_version); + + size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k); + + void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, + const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, + const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, + int inter_size, int num_experts, int k, char* workspace_ptr, T* fc2_result, + T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, + cudaStream_t stream); + + void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, + const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, + const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, + int inter_size, int num_experts, int k, char* workspace_ptr, T* fc2_result, + const bool* finished, int active_rows, T* expert_scales, + int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream); + + void compute_total_rows_before_expert(const int* sorted_indices, int total_indices, int num_experts, + int64_t* total_rows_before_expert, cudaStream_t stream); + + private: + void configure_ws_ptrs(char* ws_ptr, int num_rows, int hidden_size, int inter_size, int num_experts, int k); + + private: + CubKeyValueSorter sorter_; + MoeGemmRunner moe_gemm_runner_; + + // Pointers + int* source_rows_; + int* permuted_rows_; + int* permuted_experts_; + char* sorter_ws_; + T* permuted_data_; + T* softmax_out_; + + int64_t* total_rows_before_expert_; + + T* fc1_result_; +}; + +template +class CutlassMoeFCRunner::value>> { + public: + CutlassMoeFCRunner(int sm_version); + + size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k) { + return 0; + } +}; + +} // namespace ort_fastertransformer \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h new file mode 100644 index 0000000000000..00f977c615df6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Base scheduler for grouped problems, using MoE +*/ + +#pragma once + +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct BaseMoeProblemVisitor { + using ThreadblockShape = ThreadblockShape_; + + struct ProblemInfo { + static int32_t const kNoPrefetchEntry = -1; + int32_t problem_idx; + int32_t problem_start; + + CUTLASS_DEVICE + ProblemInfo() : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} + + CUTLASS_DEVICE + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) + : problem_idx(problem_idx_), problem_start(problem_start_) {} + }; + + struct Params { + int64_t const* last_row_for_problem; + int64_t gemm_n; + int64_t gemm_k; + int32_t problem_count; + void const* workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : last_row_for_problem(nullptr), gemm_n(0), gemm_k(0), problem_count(0), workspace(nullptr), tile_count(0) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, + void const* workspace = nullptr, int32_t tile_count = 0) + : last_row_for_problem(last_row_for_problem), + gemm_n(gemm_n), + gemm_k(gemm_k), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) {} + }; + + Params const& params; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) + : params(params_), tile_idx(block_idx), problem_tile_start(0), problem_idx(0) {} + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const { return tile_idx; } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const { return problem_idx; } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const { return tile_idx - problem_tile_start; } + + CUTLASS_DEVICE + void advance(int32_t grid_size) { tile_idx += grid_size; } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { + ProblemSizeHelper::possibly_transpose_problem(problem); + } + + /// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const { return problem_size(problem_idx); } + + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size(int idx) const { + const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t gemm_m = current_problem_row - prev_problem_row; + GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { return ProblemSizeHelper::tile_count(grid); } + + static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); + } + + return total_tiles; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeProblemVisitor; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ProblemVisitor that performs all scheduling on device +// +template +struct MoeProblemVisitor : public BaseMoeProblemVisitor { + using Base = BaseMoeProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage {}; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage& shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, block_idx), problem_ending_tile(0), shared_storage(shared_storage_) { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() { + // Check whether the tile to compute is within the range of the current problem. + int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); + if (this->tile_idx < problem_tile_end) { + return true; + } + + // Check whether the tile to compute is within the current group of problems fetched by the warp. + // The last tile for this group is the final tile of the problem held by the final thread in the warp. + int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + // Keep the starting problem for this group in `problem_idx`. This is done to reduce + // register pressure. The starting problem for this group is simply the first problem + // in the group most recently fetched by the warp. + int32_t& group_problem_start = this->problem_idx; + group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; + + // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce + // register pressure. + int32_t& group_tile_start = this->problem_tile_start; + + // Each thread in the warp processes a separate problem to advance until + // reaching a problem whose starting tile is less less than tile_idx. + while (group_tile_end <= this->tile_idx) { + group_problem_start += kThreadsPerWarp; + if (group_problem_start > this->params.problem_count) { + return false; + } + + // Since `group_tile_start` is a reference to `this->problem_tile_start`, this + // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` + // is also set here is used later in `next_tile`. + group_tile_start = group_tile_end; + + int lane_idx = threadIdx.x % kThreadsPerWarp; + int32_t lane_problem = group_problem_start + lane_idx; + + // Compute the number of tiles in the problem assigned to each thread. + problem_ending_tile = 0; + if (lane_problem < this->params.problem_count) { + cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + problem_ending_tile = this->tile_count(grid); + } + + // Compute a warp-wide inclusive prefix sum to compute the ending tile index of + // each thread's problem. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kThreadsPerWarp; i <<= 1) { + int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); + if (lane_idx >= i) { + problem_ending_tile += val; + } + } + + // The total tile count for this group is now in the final position of the prefix sum + int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + problem_ending_tile += group_tile_start; + group_tile_end += tiles_in_group; + } + + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); + + this->problem_idx = group_problem_start + problem_idx_in_group; + + // The starting tile for this problem is the ending tile of the previous problem. In cases + // where `problem_idx_in_group` is the first problem in the group, we do not need to reset + // `problem_tile_start`, because it is set to the previous group's ending tile in the while + // loop above. + if (problem_idx_in_group > 0) { + this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); + } + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count, void* host_workspace_ptr) {} +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h new file mode 100644 index 0000000000000..3505bea24e4d9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass { +namespace layout { + +template +class ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc new file mode 100644 index 0000000000000..6f2ffe7a0cc43 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_common.h" +#include "moe.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MoE, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(0, 0) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + MoE); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +using namespace ONNX_NAMESPACE; + +template +Status MoE::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* router_probs = context->Input(1); + const Tensor* fc1_experts_weights = context->Input(2); + const Tensor* fc2_experts_weights = context->Input(3); + const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc2_experts_bias_optional = context->Input(5); + + const auto& input_dims = input->Shape().GetDims(); + const auto& router_probs_dims = router_probs->Shape().GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); + const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); + + const int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1]; + const int64_t hidden_size = input_dims[input_dims.size() - 1]; + const int64_t num_experts = fc1_experts_weights_dims[0]; + const int64_t inter_size = fc1_experts_weights_dims[2]; + + // TODO: refactor to helper function. + if (fc1_experts_weights_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", + fc1_experts_weights_dims.size()); + } + if (fc2_experts_weights_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ", + fc2_experts_weights_dims.size()); + } + if (fc1_experts_weights_dims[1] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_weights_dims[1] must be equal to hidden_size, got ", + fc1_experts_weights_dims[1], " and ", hidden_size); + } + if (fc2_experts_weights_dims[1] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_weights_dims[1] must be equal to inter_size, got ", fc2_experts_weights_dims[1], + " and ", inter_size); + } + if (fc1_experts_weights_dims[2] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_weights_dims[2] must be equal to inter_size, got ", fc1_experts_weights_dims[2], + " and ", inter_size); + } + if (fc2_experts_weights_dims[2] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", + fc2_experts_weights_dims[2], " and ", hidden_size); + } + if (router_probs_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", + router_probs_dims.size()); + } + if (router_probs_dims[0] != num_rows) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ", + router_probs_dims[0], " and ", num_rows); + } + if (router_probs_dims[1] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[1] must be equal to num_experts, got ", + router_probs_dims[1], " and ", num_experts); + } + if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is set but fc2_experts_bias is not set"); + } + if (fc1_experts_bias_optional == nullptr && fc2_experts_bias_optional != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is not set but fc2_experts_bias is set"); + } + if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) { + const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims(); + const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims(); + if (fc1_experts_bias_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ", + fc1_experts_bias_dims.size()); + } + if (fc2_experts_bias_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ", + fc2_experts_bias_dims.size()); + } + if (fc1_experts_bias_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_bias_dims[0] must be equal to num_experts, got ", fc1_experts_bias_dims[0], + " and ", num_experts); + } + if (fc2_experts_bias_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], + " and ", num_experts); + } + if (fc1_experts_bias_dims[1] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], + " and ", inter_size); + } + if (fc2_experts_bias_dims[1] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", fc2_experts_bias_dims[1], + " and ", hidden_size); + } + } + + typedef typename ToCudaType::MappedType CudaT; + auto stream = context->GetComputeStream(); + + auto& device_prop = GetDeviceProp(); + const int sm = device_prop.major * 10 + device_prop.minor; + + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm); + + size_t ws_size = + moe_runner.getWorkspaceSize(static_cast(num_rows), static_cast(hidden_size), + static_cast(inter_size), static_cast(num_experts), static_cast(k_)); + size_t fc2_output_size = k_ * num_rows * hidden_size * sizeof(CudaT); + size_t expert_scales_size = k_ * num_rows * sizeof(CudaT); + size_t expanded_source_row_to_expanded_dest_row_size = k_ * num_rows * sizeof(int); + size_t expert_for_source_row_size = k_ * num_rows * sizeof(int); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + // TODO: allocate one buffer and reuse it. + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); + IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); + IAllocatorUniquePtr expert_scales = + IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, stream); + IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = + IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream); + IAllocatorUniquePtr expert_for_source_row = + IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); + + // fc1_scales and fc2_scales are used in quantized MoE + const CudaT* fc1_scales_ptr = nullptr; + const CudaT* fc2_scales_ptr = nullptr; + + moe_runner.run_moe_fc(reinterpret_cast(input->template Data()), + reinterpret_cast(router_probs->template Data()), + reinterpret_cast(fc1_experts_weights->template Data()), + std::move(fc1_scales_ptr), + fc1_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc1_experts_bias_optional->template Data()), + activation_type_, reinterpret_cast(fc2_experts_weights->template Data()), + std::move(fc2_scales_ptr), static_cast(num_rows), static_cast(hidden_size), + static_cast(inter_size), static_cast(num_experts), static_cast(k_), + reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), Stream(context)); + + Tensor* output = context->Output(0, input->Shape()); + + ort_fastertransformer::finalize_moe_routing_kernelLauncher( + reinterpret_cast(fc2_output.get()), reinterpret_cast(output->template MutableData()), + fc2_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc2_experts_bias_optional->template Data()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), static_cast(num_rows), static_cast(hidden_size), + static_cast(k_), Stream(context)); + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h new file mode 100644 index 0000000000000..8035568693814 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/moe.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class MoE final : public CudaKernel { + public: + explicit MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK()); + + std::string activation_type_str; + ORT_ENFORCE(op_kernel_info.GetAttr("activation_type", &activation_type_str).IsOK()); + if (activation_type_str == "relu") { + activation_type_ = ort_fastertransformer::ActivationType::Relu; + } else if (activation_type_str == "gelu") { + activation_type_ = ort_fastertransformer::ActivationType::Gelu; + } else if (activation_type_str == "silu") { + activation_type_ = ort_fastertransformer::ActivationType::Silu; + } else if (activation_type_str == "identity") { + activation_type_ = ort_fastertransformer::ActivationType::Identity; + } else { + ORT_THROW("Unsupported MoE activation type: ", activation_type_str); + } + } + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + int64_t k_; + ort_fastertransformer::ActivationType activation_type_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/attention.cc b/onnxruntime/contrib_ops/js/bert/attention.cc new file mode 100644 index 0000000000000..723ff00aa815e --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + Attention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + Attention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/attention.h b/onnxruntime/contrib_ops/js/bert/attention.h new file mode 100644 index 0000000000000..0fa823befa9b2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class Attention : public JsKernel, AttentionBase { + public: + explicit Attention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + std::vector qkv_sizes(qkv_hidden_sizes_.size()); + if (qkv_hidden_sizes_.size() > 0) { + std::transform(qkv_hidden_sizes_.begin(), qkv_hidden_sizes_.end(), qkv_sizes.begin(), + [](int64_t sz) { return gsl::narrow_cast(sz); }); + } + + JSEP_INIT_KERNEL_ATTRIBUTE(Attention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + "qkvHiddenSizes" : $6 ? (Array.from(HEAP32.subarray(Number($7), Number($7) + $6))) : [], + "pastPresentShareBuffer" : !!$8, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_), + static_cast(qkv_hidden_sizes_.size()), + reinterpret_cast((qkv_sizes.size() > 0) ? qkv_sizes.data() : nullptr) >> 2, + static_cast(past_present_share_buffer_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc new file mode 100644 index 0000000000000..c43f8b7f18465 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "multi_head_attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + MultiHeadAttention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.h b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h new file mode 100644 index 0000000000000..6c63a2ffed4b2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class MultiHeadAttention : public JsKernel, AttentionBase { + public: + explicit MultiHeadAttention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + JSEP_INIT_KERNEL_ATTRIBUTE(MultiHeadAttention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 24d327576ecd9..498a9f5679eb5 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -7,7 +7,9 @@ namespace onnxruntime { namespace contrib { namespace js { +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); @@ -21,7 +23,9 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo // For std::cerr. + // Writing to stderr instead of logging because logger may not be initialized yet. namespace onnxruntime { @@ -137,7 +138,7 @@ void decodeMIDR( break; // #endif /* ARM */ default: - LOGS_DEFAULT(WARNING) << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } } break; @@ -156,7 +157,7 @@ void decodeMIDR( break; // #endif default: - LOGS_DEFAULT(WARNING) << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if (defined(_M_ARM64) || defined(__aarch64__)) && !defined(__ANDROID__) @@ -172,7 +173,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_thunderx2; break; default: - LOGS_DEFAULT(WARNING) << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif @@ -187,7 +188,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_cortex_a76; break; default: - LOGS_DEFAULT(WARNING) << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -199,7 +200,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_xscale; break; default: - LOGS_DEFAULT(WARNING) << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ @@ -215,7 +216,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_carmel; break; default: - LOGS_DEFAULT(WARNING) << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #if !defined(__ANDROID__) @@ -225,7 +226,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_xgene; break; default: - LOGS_DEFAULT(WARNING) << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #endif @@ -297,7 +298,7 @@ void decodeMIDR( break; // #endif /* ARM64 && !defined(__ANDROID__) */ default: - LOGS_DEFAULT(WARNING) << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; case 'S': @@ -343,8 +344,9 @@ void decodeMIDR( *uarch = cpuinfo_uarch_exynos_m5; break; default: - LOGS_DEFAULT(WARNING) << "unknown Samsung CPU variant 0x" - << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Samsung CPU variant 0x" + << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) + << " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -355,12 +357,12 @@ void decodeMIDR( *uarch = cpuinfo_uarch_pj4; break; default: - LOGS_DEFAULT(WARNING) << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ default: - LOGS_DEFAULT(WARNING) << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr; + std::cerr << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr << "\n"; } } diff --git a/onnxruntime/core/framework/op_node_proto_helper.cc b/onnxruntime/core/framework/op_node_proto_helper.cc index 38d67eb0e0c72..c3deb94300e78 100644 --- a/onnxruntime/core/framework/op_node_proto_helper.cc +++ b/onnxruntime/core/framework/op_node_proto_helper.cc @@ -182,7 +182,7 @@ ORT_DEFINE_GET_ATTRS_SPAN_SPECIALIZATION(float, floats) ORT_DEFINE_GET_ATTRS_SPAN_SPECIALIZATION(int64_t, ints) template -MUST_USE_RESULT Status OpNodeProtoHelper::GetAttrs(const std::string& name, TensorShapeVector& out) const { +Status OpNodeProtoHelper::GetAttrs(const std::string& name, TensorShapeVector& out) const { gsl::span span; Status status = this->GetAttrsAsSpan(name, span); if (status.IsOK()) { @@ -193,7 +193,7 @@ MUST_USE_RESULT Status OpNodeProtoHelper::GetAttrs(const std::string& na } template -MUST_USE_RESULT Status OpNodeProtoHelper::GetAttrsStringRefs( +Status OpNodeProtoHelper::GetAttrsStringRefs( const std::string& name, std::vector>& refs) const { const AttributeProto* attr = TryGetAttribute(name); diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index df3a7afebc176..df11fe8302aef 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -455,11 +455,10 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::GraphViewer& // utils::CopyOneInputAcrossDevices is happy. auto& input_map = session_state.GetInputNodeInfoMap(); - auto end_map = input_map.cend(); for (const auto& graph_input : graph_inputs) { const auto& name = graph_input->Name(); - if (input_map.find(name) == end_map) { + if (input_map.find(name) == input_map.cend()) { // dummy entry for an input that we didn't find a use of in the graph. log it in case that's a bug. // utils::CopyOneInputAcrossDevices will use the input OrtValue as is given we don't believe it's used anywhere. LOGS(session_state.Logger(), INFO) << (graph.IsSubgraph() ? "Subgraph" : "Graph") << " input with name " diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index dcde2ddeb8270..b97fb0d2899fc 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -991,7 +991,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( constexpr const char* GroupQueryAttention_ver1_doc = R"DOC( Group Query Self/Cross Attention. -Supports different number of heads for q and kv. +Supports different number of heads for q and kv. Only supports causal or local attention. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1004,10 +1004,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) - // .Attr("left_padding_last_token", - // "Copy last token to last index of buffer. Default is 0; 1 when true.", - // AttributeProto::INT, - // OPTIONAL_VALUE) + .Attr("local_window_size", + "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", + AttributeProto::INT, + static_cast(-1)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size)", @@ -1144,7 +1144,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OPTIONAL_VALUE) .Input(0, "input", - "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", "T") .Input(1, "position_ids", @@ -1160,7 +1160,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T") .Output(0, "output", - "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "tensor with same shape as input.", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 39449bea6303a..db0b13b0e1d27 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1375,6 +1375,27 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, GreedySearchShapeInference(ctx); })); +constexpr const char* MoE_ver1_doc = R"DOC( + Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, + GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) + usually uses top 32 experts. + )DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, + OpSchema() + .SetDoc(MoE_ver1_doc) + .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) + .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) + .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") + .Input(3, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") + .Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) + .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, OpSchema() .Input(0, "X", "input", "T") diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index b35cfc5d12f36..5eef1b33a24dd 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -83,6 +83,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4); #endif class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MoE); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); @@ -189,6 +190,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); #endif fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/mlas/.clang-format b/onnxruntime/core/mlas/.clang-format index 4a89ef98cf049..16ad8bd8a7234 100644 --- a/onnxruntime/core/mlas/.clang-format +++ b/onnxruntime/core/mlas/.clang-format @@ -2,10 +2,12 @@ BasedOnStyle: Google IndentWidth: 4 -ColumnLimit: 100 +# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained. +# Developers are responsible for adhering to the 120 character maximum. +ColumnLimit: 0 +AlignAfterOpenBracket: BlockIndent AlwaysBreakAfterReturnType: TopLevel AlwaysBreakTemplateDeclarations: Yes BinPackParameters: false BreakBeforeBraces: Linux ... - diff --git a/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h new file mode 100644 index 0000000000000..7ea29eb091318 --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_gemm_postprocessor.h @@ -0,0 +1,33 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_gemm_postprocessor.h + +Abstract: + + This module contains a base class for custom postprocessing following a + GEMM. + +--*/ + +#pragma once + +template +class MLAS_GEMM_POSTPROCESSOR +{ + public: + virtual void Process(T* C, /**< the address of matrix to process */ + size_t RangeStartM, /**< the start row index of matrix */ + size_t RangeStartN, /**< the start col index of matrix */ + size_t RangeCountM, /**< the element count per row to process */ + size_t RangeCountN, /**< the element count per col to process */ + size_t ldc /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_GEMM_POSTPROCESSOR() {} +}; diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 7c7b729117e4a..316344ad8c214 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -21,6 +21,7 @@ Module Name: #pragma once #include "mlas.h" +#include "mlas_gemm_postprocessor.h" #include #include @@ -95,22 +96,6 @@ MlasQ4GemmUnPackB( ); -template -class MLAS_GEMM_POSTPROCESSOR -{ - public: - virtual void Process(T*, /**< the address of matrix to process */ - size_t, /**< the start row index of matrix */ - size_t, /**< the start col index of matrix */ - size_t, /**< the element count per row to process */ - size_t, /**< the element count per col to process */ - size_t /**< the leading dimension of matrix */ - ) const = 0; - - virtual ~MLAS_GEMM_POSTPROCESSOR() {} -}; - - /** * @brief Data parameters for Q4 GEMM routine * C = A * B + Bias @@ -241,7 +226,7 @@ MlasQ8Q4GemmBatch( * matrix shape [rows, columns], compute the shape of the * quantization parameter matrix [meta_rows, meta_cols] */ -template +template void MlasBlockwiseQuantMetaShape( int block_size, @@ -259,6 +244,7 @@ MlasBlockwiseQuantMetaShape( * is in column major layout, with bits packed on the column. * * @tparam T + * @tparam qbits * @param block_size * @param columnwise * @param rows @@ -266,7 +252,7 @@ MlasBlockwiseQuantMetaShape( * @param q_rows * @param q_cols */ -template +template void MlasBlockwiseQuantizedShape( int block_size, @@ -277,6 +263,32 @@ MlasBlockwiseQuantizedShape( int& q_cols ); +/** + * @brief Compute the sizes of the quantized data and quantization parameter buffers. + * + * @param qbits The bit width of each quantized value. + * @param block_size The number of quantized values in a block. + * @param columnwise Whether a block contains values from a matrix column (true) or row (false). + * @param rows Number of matrix rows. + * @param columns Number of matrix columns. + * @param[out] q_data_size_in_bytes The size in bytes of the quantized data. + * @param[out] q_scale_num_elements The size in elements of the scale quantization parameters. + * @param[out] q_zero_point_size_in_bytes The size in bytes of the zero point quantization parameters. Optional. + * + * If the qbits or block_size values are unsupported the output sizes will be zero. + */ +void MLASCALL +MlasBlockwiseQuantizedBufferSizes( + int qbits, + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + /** * @brief Blockwise 4 bits quantization, resulting elements and quantization diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h new file mode 100644 index 0000000000000..9620dd42d1da9 --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -0,0 +1,79 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_qnbit.h + +Abstract: + + This module contains the public data structures and procedure prototypes + for blocked n-bit quantized GEMM. + + N-bit block quantization is used to compress weight tensors of large + language models. + +--*/ + +#pragma once + +#include "mlas.h" +#include "mlas_gemm_postprocessor.h" + +/** + * @brief Data parameters for float/n-bit quantized int GEMM routine. + */ +struct MLAS_SQNBIT_GEMM_DATA_PARAMS { + const float* A = nullptr; ///< address of A (float32 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) + const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + bool IsBPacked = false; ///< whether B values are packed in an optimized format for the computation + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C + + ///< optional post processing to apply to result matrix + MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; +}; + +/** + * @brief Batched GEMM: C = A * B + Bias + * A must be a float32 matrix + * B must be a quantized and packed n-bit int matrix + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool optional thread pool to use + */ +void MLASCALL +MlasSQNBitGemmBatch( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkBitWidth, + size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool = nullptr +); + +/** + * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + */ +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t BlkBitWidth, + size_t BlkLen +); diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index e0c2772cbb719..6c859e4e4f44b 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -890,14 +890,30 @@ extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchNeon; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymU8DispatchDot; extern const MLAS_CONV_SYM_DISPATCH MlasConvSymS8DispatchDot; +// +// Quantized 8-bit integer/quantized 4-bit integer matrix/matrix multiply dispatch structure. +// + struct MLAS_Q8Q4GEMM_DISPATCH; extern const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni; +// +// Float/quantized 4-bit integer matrix/matrix multiply dispatch structure. +// + struct MLAS_FPQ4GEMM_DISPATCH; extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; +// +// Float/quantized n-bit integer matrix/matrix multiply dispatch structure. +// + +struct MLAS_SQNBIT_GEMM_DISPATCH; + +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; + // // Quantized depthwise convolution kernels. // @@ -1029,6 +1045,8 @@ struct MLAS_PLATFORM { const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; + + const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 39586282e00ad..fec56c6ee063f 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -460,6 +460,7 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index fbd1030de8ab7..48d975a7fd26d 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -422,6 +422,24 @@ struct BlockwiseQuantizer { q_cols = meta_cols * QuantBlk::kColumn; } + static MLAS_FORCEINLINE void quantizedBufferSizes( + int rows, int columns, size_t& data_bytes, size_t& scale_num_elements, size_t* zero_point_bytes + ) + { + int meta_rows, meta_cols; + quantizeMetaShape(rows, columns, meta_rows, meta_cols); + int q_rows, q_cols; + quantizedShape(rows, columns, q_rows, q_cols); + + data_bytes = q_rows * q_cols; + scale_num_elements = meta_rows * meta_cols; + + if (zero_point_bytes) { + // this works for qbits == 4 but may need to be updated for other qbits values + *zero_point_bytes = ((meta_rows * qbits + 7) / 8) * meta_cols; + } + } + /** * @brief Quantized a Matrix shape [rows, columns], resulting quantized * and packed data are stored in column major (transposed) @@ -621,7 +639,7 @@ struct BlockwiseQuantizer { }; -template +template void MlasBlockwiseQuantMetaShape( int block_size, @@ -635,47 +653,47 @@ MlasBlockwiseQuantMetaShape( switch (block_size) { case 16: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } break; } case 32: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape( + BlockwiseQuantizer::quantizeMetaShape( rows, columns, meta_rows, meta_cols); } break; } case 64: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } break; } case 128: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } break; } case 256: { if (columnwise) { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } else { - BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, + BlockwiseQuantizer::quantizeMetaShape(rows, columns, meta_rows, meta_cols); } break; @@ -689,7 +707,7 @@ MlasBlockwiseQuantMetaShape( -template +template void MlasBlockwiseQuantizedShape( int block_size, @@ -703,42 +721,42 @@ MlasBlockwiseQuantizedShape( switch (block_size) { case 16: { if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } break; } case 32: { if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape( + BlockwiseQuantizer::quantizedShape( rows, columns, q_rows, q_cols); } break; } case 64: { if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } break; } case 128: { if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } break; } case 256: { if (columnwise) { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } else { - BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); + BlockwiseQuantizer::quantizedShape(rows, columns, q_rows, q_cols); } break; } @@ -752,7 +770,7 @@ MlasBlockwiseQuantizedShape( template void -MlasBlockwiseQuantMetaShape( +MlasBlockwiseQuantMetaShape( int block_size, bool columnwise, int rows, @@ -763,7 +781,7 @@ MlasBlockwiseQuantMetaShape( template void -MlasBlockwiseQuantizedShape( +MlasBlockwiseQuantizedShape( int block_size, bool columnwise, int rows, @@ -773,6 +791,93 @@ MlasBlockwiseQuantizedShape( ); +void MLASCALL +MlasBlockwiseQuantizedBufferSizes( + int qbits, + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +) +{ + q_data_size_in_bytes = q_scale_num_elements = 0; + if (q_zero_point_size_in_bytes) { + *q_zero_point_size_in_bytes = 0; + } + + if (qbits == 4) { + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. + break; + } + } +} + + template void MlasQuantizeBlockwise( diff --git a/onnxruntime/core/mlas/lib/q4gemm.h b/onnxruntime/core/mlas/lib/q4gemm.h index 1562f9c0b4236..b1b51dd53c4fc 100644 --- a/onnxruntime/core/mlas/lib/q4gemm.h +++ b/onnxruntime/core/mlas/lib/q4gemm.h @@ -90,7 +90,7 @@ MlasQ4GemmOperation( if (DataParams->OutputProcessor != nullptr) { DataParams->OutputProcessor->Process( - DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN + n, RowsHandled, CountN, ldc); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp new file mode 100644 index 0000000000000..f964b1affec31 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -0,0 +1,144 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch. +--*/ + +#include "sqnbitgemm.h" + +namespace +{ + +// Get quantization variant based on `BlkBitWidth` and `BlkLen`. +// Return -1 if the input values are unsupported. +int32_t +GetDispatchQuantVariant(size_t BlkBitWidth, size_t BlkLen) +{ + int32_t type = -1; + if (BlkBitWidth == 4 && BlkLen == 16) { + type = QuantVariant_BitWidth4_BlockSize16; + } else if (BlkBitWidth == 4 && BlkLen == 32) { + type = QuantVariant_BitWidth4_BlockSize32; + } else if (BlkBitWidth == 4 && BlkLen == 64) { + type = QuantVariant_BitWidth4_BlockSize64; + } else if (BlkBitWidth == 4 && BlkLen == 128) { + type = QuantVariant_BitWidth4_BlockSize128; + } else if (BlkBitWidth == 4 && BlkLen == 256) { + type = QuantVariant_BitWidth4_BlockSize256; + } + + return type; +} + +} // namespace + +void MLASCALL +MlasSQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool +) +{ + const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); + MLAS_SQNBIT_GEMM_OPERATION* const Operation = GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant]; + + if (ThreadPool == nullptr) { + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + auto Data = &DataParams[gemm_i]; + Operation(K, Data, 0, M, 0, N); + } + return; + } + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K) * double(BatchN); + + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8; + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + constexpr size_t StrideM = 128; + + size_t nc = N; + if (ThreadsPerGemm > 1) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min( + nc, MlasDivRoundup(max_nc, MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * + MLAS_QGEMM_STRIDEN_THREAD_ALIGN + ); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + auto Data = &DataParams[gemm_i]; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + Operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} + +bool MLASCALL +MlasIsSQNBitGemmAvailable( + size_t BlkBitWidth, + size_t BlkLen +) +{ + const int32_t QuantVariant = GetDispatchQuantVariant(BlkBitWidth, BlkLen); + if (QuantVariant == -1) { + return false; + } + + if (GetMlasPlatform().SQNBitGemmDispatch == nullptr || + GetMlasPlatform().SQNBitGemmDispatch->Operations[QuantVariant] == nullptr) { + return false; + } + + return true; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h new file mode 100644 index 0000000000000..f8f7dcd43699f --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -0,0 +1,287 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm.h + +Abstract: + + This module includes: + + - Declaration of the set of template functions used to implement a kernel + for a matrix/matrix multiplication, A*B, where A is a float matrix and B is + a n-bit quantized integer matrix (QNBitGemm). + + - A shared kernel driver function template, MlasSQNBitGemmOperation. + + - Kernel dispatch structure. + + The B matrix is block quantized, which means that its values are grouped + into blocks which each have one scale and optional zero point. Each + quantized value in B is n-bits wide. + +--*/ + +#pragma once + +#include "mlas_qnbit.h" +#include "mlasi.h" + +// +// Kernel implementation template declarations +// + +/** + * @brief Multiply float matrix A with quantized n-bit integer matrix B. + * B is block quantized and column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @tparam BlkBitWidth Bit width of each value in a block. + * @tparam BlkLen Number of values in a block. + * @tparam KernelType Hardware-specific kernel type. + * + * @param A Supplies the A matrix. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + */ +template +MLAS_FORCEINLINE void +MlasSQNBitGemmM1Kernel( + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +); + +/** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is block quantized and column major. + * This is equivalent to dequantizing B and then running + * MlasSgemmCopyPackB. + * + * @tparam BlkBitWidth Bit width of each value in a block. + * @tparam BlkLen Number of values in a block. + * @tparam KernelType Hardware-specific kernel type. + * + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ +template +MLAS_FORCEINLINE void +MlasQNBitBlkDequantBForSgemm( + float* FpData, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB +); + +// +// MlasQNBitGemmOperation and helpers +// + +constexpr MLAS_FORCEINLINE size_t +MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) +{ + return BlkLen * BlkBitWidth / 8; +} + +template +constexpr MLAS_FORCEINLINE size_t +MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) +{ + if constexpr (BlkBitWidth <= 4) { + return MlasDivRoundup(BlkCount, 2); // 2 blocks per byte + } else { + return BlkCount; + } +} + +MLAS_FORCEINLINE void +MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) +{ + for (size_t m = 0; m < CountM; m++) { + const float* bias = Bias; + float* sum = C; + for (size_t n = 0; n < CountN; n += 4) { + if (CountN - n < 4) { + for (size_t nn = n; nn < CountN; nn++) { + *sum += *bias; + sum++; + bias++; + } + break; + } + + MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum); + acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias)); + MlasStoreFloat32x4(sum, acc_x); + bias += 4; + sum += 4; + } + C += ldc; + } +} + +template +MLAS_FORCEINLINE void MLASCALL +MlasSQNBitGemmOperation( + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const float* A = DataParams->A + RangeStartM * lda; + + const uint8_t* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const uint8_t* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + if (RangeCountM == 1) { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const float* a_row = A; + const uint8_t* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const uint8_t* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + MlasSQNBitGemmM1Kernel( + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + return; + } + + constexpr size_t StrideN = 32; + size_t bufsize = k_blks * BlkLen * StrideN * sizeof(float); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + // + // Step through each slice of matrix B along the N dimension. + // + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, StrideN); + + // + // Step through each slice of matrix A along the M dimension. + // + const float* a_row = A; + const uint8_t* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const uint8_t* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + MlasQNBitBlkDequantBForSgemm( + dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks + ); + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true + ); +#else + auto RowsHandled = MlasSgemmKernelZero(a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f); +#endif + + if (bias) { + MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc); + } + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + RangeCountM - RowsRemaining, RangeStartN, + RowsHandled, CountN, ldc + ); + } + + c_blk += ldc * RowsHandled; + a_row += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } +} + +// +// Kernel dispatch structure. +// + +typedef void(MLASCALL MLAS_SQNBIT_GEMM_OPERATION)( + size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + size_t RangeStartM, + size_t RangeCountM, + size_t RangeStartN, + size_t RangeCountN +); + +enum QuantVariant { + QuantVariant_BitWidth4_BlockSize16, + QuantVariant_BitWidth4_BlockSize32, + QuantVariant_BitWidth4_BlockSize64, + QuantVariant_BitWidth4_BlockSize128, + QuantVariant_BitWidth4_BlockSize256, + QuantVariantCount, // Keep this element last and ensure that its value is the number of other QuantVariant values. + // Its value is used as an array size. +}; + +struct MLAS_SQNBIT_GEMM_DISPATCH { + MLAS_SQNBIT_GEMM_OPERATION* Operations[QuantVariantCount] = { + // Initialized to nullptrs. Overwrite in hardware-specific kernel implementation. + }; +}; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..63afe57dd9137 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -0,0 +1,489 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon.h + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON. + +--*/ + +#include "sqnbitgemm.h" + +#include + +#include +#include +#include + +// +// Hardware-specific kernel type. +// +struct MLAS_SQNBIT_GEMM_KERNEL_NEON { +}; + +namespace +{ + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +MLAS_FORCEINLINE float32x4_t +FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) +{ + // aN: aN_0 aN_1 aN_2 aN_3 + + float32x4_t b0 = vzip1q_f32(a0, a1); // a0_0 a1_0 a0_1 a1_1 + float32x4_t b1 = vzip2q_f32(a0, a1); // a0_2 a1_2 a0_3 a1_3 + float32x4_t b2 = vzip1q_f32(a2, a3); // a2_0 a3_0 a2_1 a3_1 + float32x4_t b3 = vzip2q_f32(a2, a3); // a2_2 a3_2 a2_3 a3_3 + + // a0_0 a1_0 a2_0 a3_0 + a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_1 a1_1 a2_1 a3_1 + a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2))); + // a0_2 a1_2 a3_2 a3_2 + a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); + // a0_3 a1_3 a2_3 a3_3 + a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3))); + + return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); +} + +template +MLAS_FORCEINLINE void +LoadData(const float* src, size_t count, float32x4_t (& dst)[Capacity / 4]) +{ + static_assert(Capacity % 4 == 0, "Capacity must be divisible by 4."); + + assert(count <= Capacity); + + size_t vi = 0; // vector index + + // handle 4 values at a time + while (count > 3) { + dst[vi] = vld1q_f32(src); + + vi += 1; + src += 4; + count -= 4; + } + + // handle remaining values + if (count > 0) { + dst[vi] = vsetq_lane_f32(src[0], dst[vi], 0); + + if (count > 1) { + dst[vi] = vsetq_lane_f32(src[1], dst[vi], 1); + + if (count > 2) { + dst[vi] = vsetq_lane_f32(src[2], dst[vi], 2); + } + } + } +} + +template +MLAS_FORCEINLINE void +ComputeDotProducts( + const float* ARowPtr, + const uint8_t* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const uint8_t* QuantBZeroPointColPtr, + float* SumPtr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* BiasPtr +) +{ + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + + const uint8x8_t LowMask = vdup_n_u8(0x0F); + + // Manual conversion to float takes place in two steps: + // 1. Map 4-bit values from [0, 15] to float values from [16.0f, 31.0f]. + // This target float range is convenient because the 4-bit source values can be placed directly into the + // target float bits. + // 2. Subtract the conversion offset of 16 from the float result. + + // The high 16 bits of an IEEE 754 32-bit float used as a template for creating float values. + constexpr uint16_t float_high_half_template = 0b0'10000011'0000000; + // sign|exponent|partial mantissa + // +|131: 2^4|~~~~ <- 4 bits go here + + const uint16x8_t float_high_half_template_v = vdupq_n_u16(float_high_half_template); + + float32x4_t acc[NCols]{}; + + const uint8_t* QuantBData = QuantBDataColPtr; + const float* QuantBScale = QuantBScaleColPtr; + size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); + + float scale[NCols]; + UnrolledLoop( + [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } + ); + + float offset[NCols]; // Includes zero point and float conversion offset of 16. + if (QuantBZeroPointColPtr != nullptr) { + UnrolledLoop([&](size_t i) { + const uint8_t zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const uint8_t zp = ((QuantBZeroPointIdx & 1) == 1) ? (zp_packed >> 4) : (zp_packed & 0x0F); + offset[i] = 16.0f + zp; + }); + } else { + UnrolledLoop([&](size_t i) { + constexpr float zp = 8.0f; + offset[i] = 16.0f + zp; + }); + } + + constexpr size_t SubBlkLen = 16; // number of block elements to process in one iteration + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { + // load A row vector elements + + // load `SubBlkLen` elements from A, padded with 0's if there aren't enough + const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, SubBlkLen); + float32x4_t av[4]{}; + LoadData(ARowPtr + k + k_idx_in_blk, k_subblk_len, av); + + // load B column vectors + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; + bv_packed[i] = vld1_u8(QuantBData + i * StrideQuantBData + b_data_block_offset); + }); + + uint8x8_t bv_u8_unzipped[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8_unzipped[i][0] = vand_u8(bv_packed[i], LowMask); + bv_u8_unzipped[i][1] = vand_u8(vshr_n_u8(bv_packed[i], 4), LowMask); + }); + + uint8x8_t bv_u8[NCols][2]; + UnrolledLoop([&](size_t i) { + bv_u8[i][0] = vzip1_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); + bv_u8[i][1] = vzip2_u8(bv_u8_unzipped[i][0], bv_u8_unzipped[i][1]); + }); + + // dequantize B + + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { + constexpr int shift = 3; + bv_u16[i][0] = vshll_n_u8(bv_u8[i][0], shift); + bv_u16[i][1] = vshll_n_u8(bv_u8[i][1], shift); + }); + + // combine 4 bits with float high half template + UnrolledLoop([&](size_t i) { + bv_u16[i][0] = vorrq_u16(bv_u16[i][0], float_high_half_template_v); + bv_u16[i][1] = vorrq_u16(bv_u16[i][1], float_high_half_template_v); + }); + + // `SubBlkLen` floats of B + float32x4_t bv[NCols][4]; + + // shift left 16, widen to 32 bits, and reinterpret as float + UnrolledLoop([&](size_t i) { + constexpr int shift = 16; + bv[i][0] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][0]), shift)); + bv[i][1] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][0], shift)); + + bv[i][2] = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(bv_u16[i][1]), shift)); + bv[i][3] = vreinterpretq_f32_u32(vshll_high_n_u16(bv_u16[i][1], shift)); + }); + + // subtract float conversion offset (16) and zero point + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + + // multiply by scale + UnrolledLoop([&](size_t i) { + const float32x4_t scale_v = vdupq_n_f32(scale[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vmulq_f32(bv[i][j], scale_v); }); + }); + + // c[m,n] += a[m,k] * b[k,n] + UnrolledLoop<4>([&](size_t j) { + UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); + }); + } + + // increment pointers to next block + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; + QuantBZeroPointIdx += 1; + } + + if constexpr (NCols == 4) { + float32x4_t sum = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (BiasPtr != nullptr) { + sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sum); + } else { + for (size_t i = 0; i < NCols; ++i) { + SumPtr[i] = vaddvq_f32(acc[i]); + if (BiasPtr != nullptr) { + SumPtr[i] += BiasPtr[i]; + } + } + } +} + +} // namespace + +// +// MlasSQNBitGemmKernel and helpers. +// + +template +MLAS_FORCEINLINE void +MlasSQNBitGemmM1KernelNeon( + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t NCols = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const uint8_t* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + ComputeDotProducts( + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next `NCols` columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; + if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + + nblk -= NCols; + } + + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts( + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +#define SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(BlkBitWidth, BlkLen) \ + template <> \ + MLAS_FORCEINLINE void \ + MlasSQNBitGemmM1Kernel( \ + const float* A, \ + const uint8_t* QuantBData, \ + const float* QuantBScale, \ + const uint8_t* QuantBZeroPoint, \ + float* C, \ + size_t CountN, \ + size_t CountK, \ + size_t BlockStrideQuantB, \ + const float* Bias \ + ) \ + { \ + return MlasSQNBitGemmM1KernelNeon( \ + A, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, \ + BlockStrideQuantB, Bias \ + ); \ + } + +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 16) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 32) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 64) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 128) +SPECIALIZE_SQNBIT_GEMM_M1_KERNEL(4, 256) + +#undef SPECIALIZE_SQNBIT_GEMM_M1_KERNEL + +// +// MlasQNBitBlkDequantBForSgemm and helpers. +// + +template +MLAS_FORCEINLINE void +MlasQNBitBlkDequantBForSgemmNeon( + float* FpData, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB +) +{ + auto impl0_reference = [&]() { + static_assert(BlkBitWidth == 4); + + float* Dst = FpData; + + const uint8_t* QuantBDataCol = QuantBData; + const float* QuantBScaleCol = QuantBScale; + const uint8_t* QuantBZeroPointCol = QuantBZeroPoint; + + for (size_t n = 0; n < CountN; n += 16) { + const size_t nnlen = std::min(CountN - n, size_t{16}); + + for (size_t nn = 0; nn < nnlen; ++nn) { + for (size_t k = 0, k_blk_idx = 0; k < CountK; k += BlkLen, k_blk_idx += 1) { + const size_t kklen = std::min(CountK - k, BlkLen); + + const uint8_t* b_data = + QuantBDataCol + k_blk_idx * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const float b_s = QuantBScaleCol[k_blk_idx]; + const uint8_t b_z = + (QuantBZeroPointCol != nullptr) + ? ((k_blk_idx & 1) == 1) + ? QuantBZeroPointCol[k_blk_idx / 2] >> 4 + : QuantBZeroPointCol[k_blk_idx / 2] & 0x0F + : 8; + + for (size_t kk = 0; kk < kklen; ++kk) { + const uint8_t b_packed = b_data[kk / 2]; + const uint8_t b_byte = ((kk & 1) == 1) ? b_packed >> 4 : b_packed & 0x0F; + const float b_value = (b_byte - b_z) * b_s; + + Dst[(k + kk) * 16 + nn] = b_value; + } + } + + QuantBDataCol += BlockStrideQuantB * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScaleCol += BlockStrideQuantB; + if (QuantBZeroPointCol != nullptr) { + QuantBZeroPointCol += MlasQNBitZeroPointsForBlksSizeInBytes(BlockStrideQuantB); + } + } + + // zero out any remaining columns + + if (nnlen < 16) { + for (size_t k = 0; k < CountK; ++k) { + std::fill_n(Dst + (k * 16) + nnlen, 16 - nnlen, 0.0f); + } + } + + Dst += CountK * 16; + } + }; + + impl0_reference(); +} + +#define SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(BlkBitWidth, BlkLen) \ + template <> \ + MLAS_FORCEINLINE void \ + MlasQNBitBlkDequantBForSgemm( \ + float* FpData, \ + const uint8_t* QuantBData, \ + const float* QuantBScale, \ + const uint8_t* QuantBZeroPoint, \ + size_t CountN, \ + size_t CountK, \ + size_t BlockStrideQuantB \ + ) \ + { \ + MlasQNBitBlkDequantBForSgemmNeon( \ + FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB \ + ); \ + } + +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 16) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 32) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 64) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 128) +SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM(4, 256) + +#undef SPECIALIZE_QNBIT_BLK_DEQUANT_B_FOR_SGEMM + +// +// Kernel dispatch structure definition. +// + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + d.Operations[QuantVariant_BitWidth4_BlockSize16] = MlasSQNBitGemmOperation<4, 16, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize32] = MlasSQNBitGemmOperation<4, 32, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize64] = MlasSQNBitGemmOperation<4, 64, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize128] = MlasSQNBitGemmOperation<4, 128, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + d.Operations[QuantVariant_BitWidth4_BlockSize256] = MlasSQNBitGemmOperation<4, 256, MLAS_SQNBIT_GEMM_KERNEL_NEON>; + return d; +}(); diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc index 094ea1e24dd92..9c98ed6d3e114 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc @@ -338,8 +338,8 @@ std::optional IsSupportedGather(Graph& graph, Node& node, auto axis = static_cast(node.GetAttributes().at("axis").i()); axis = axis < 0 ? axis + data_rank : axis; size_t dim_size = static_cast(indices_shape->dim_size()); - bool is_single_value_1d_tensor = dim_size != 0 && (dim_size == 1 && utils::HasDimValue(indices_shape->dim(0)) && - indices_shape->dim(0).dim_value() == 1); + bool is_single_value_1d_tensor = dim_size == 1 && utils::HasDimValue(indices_shape->dim(0)) && + indices_shape->dim(0).dim_value() == 1; if (dim_size != 0 && !is_single_value_1d_tensor) { if (dim_size == 1 && utils::HasDimValue(data_shape->dim(axis)) && data_shape->dim(axis).dim_value() > indices_shape->dim(0).dim_value()) { diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc index 716b027068ba1..23f7c45fba4ba 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc @@ -3,6 +3,7 @@ #ifdef ENABLE_TRAINING +#include #include "core/optimizer/utils.h" #include "core/optimizer/compute_optimizer/upstream_reshape_actors.h" @@ -282,6 +283,23 @@ bool LayerNormalizationReshapeActor::PreCheck( return propagate_input_indices.size() > 0; } +bool LayerNormalizationReshapeActor::PostProcess( + Graph& /* graph */, Node& current_node, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + // When Reshape(from 3D to 2D, with the first two dimensions be merged) upstream a LayerNormalization, + // The axis attribute of LayerNormalization should be decreased by 1 if it is greater than 1. + if (axis > 1) { + auto new_axis = axis - 1; + auto& attributes = current_node.GetMutableAttributes(); + attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis)); + } + return true; +} + template class SimplePointwiseReshapeActor; template class SimplePointwiseReshapeActor; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h index 05bcbabe9ba4c..de50a56fd8781 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h @@ -111,13 +111,11 @@ class UpStreamReshapeOperatorActorBase : public UpStreamOperatorActorBase { * So far, we don't have requirements to override PostProcess function. */ - bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, - const logging::Logger& /* logger */, - std::vector& /* propagate_input_indices */, - const std::unordered_map>& /* all_input_cmp_rets */, - const std::unordered_map& /* new_reshape_infos */) { - return true; - } + virtual bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) = 0; }; // The inputs are broad-cast-able. The outputs should have the same shape (fully broadcasted shape) @@ -133,6 +131,14 @@ class SimplePointwiseReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override { + return true; + } }; class MatMulReshapeActor : public UpStreamReshapeOperatorActorBase { @@ -145,6 +151,14 @@ class MatMulReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override { + return true; + } }; class LayerNormalizationReshapeActor : public UpStreamReshapeOperatorActorBase { @@ -157,6 +171,12 @@ class LayerNormalizationReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& current_node, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override; }; /** diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index b994028cbca13..4903bc1d6b961 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -9,7 +9,8 @@ namespace onnxruntime { -bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const { +bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { return false; @@ -53,6 +54,22 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + InlinedVector node_args; + for (auto node_arg : graph.GetInputs()) { + if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) { + node_args.push_back(node_arg); + } + } + + for (auto entry : graph.GetAllInitializedTensors()) { + if (graph.GetConsumerNodes(entry.first).size() > 1) { + auto node_arg = graph.GetNodeArg(entry.first); + if (node_arg) { + node_args.push_back(node_arg); + } + } + } + for (auto node_index : node_topology_list) { auto* p_node = graph.GetNode(node_index); if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion @@ -73,7 +90,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le size_t output_count = node.GetOutputEdgesCount(); if (output_count <= 1) continue; - auto shape = node.MutableOutputDefs()[0]->Shape(); + node_args.push_back(node.OutputDefs()[0]); + } + + for (const NodeArg* node_arg : node_args) { + auto shape = node_arg->Shape(); if (!shape) continue; int64_t rank = static_cast(shape->dim_size()); @@ -81,11 +102,14 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le bool first_edge = true; int64_t split_axis = 0; int64_t indices_n_dims = -1; - InlinedVector gather_outputs(output_count, nullptr); + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); + InlinedVector gather_outputs(consumer_count, nullptr); InlinedVector> nodes_to_fuse; - for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) { + for (auto consumer : consumers) { int64_t index, axis, dims; - if (!IsSupportedGather(graph, *it, index, axis, dims)) { + if (!consumer || consumer->InputDefs()[0] != node_arg || + !IsSupportedGather(graph, *consumer, index, axis, dims)) { can_fuse = false; break; } @@ -99,7 +123,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (axis < 0) axis += rank; if (first_edge) { auto dim = shape->dim(static_cast(axis)); - if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(output_count)) { + if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(consumer_count)) { can_fuse = false; break; } @@ -109,12 +133,12 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le can_fuse = false; break; } - if (index < 0) index += static_cast(output_count); - if (index < 0 || index >= static_cast(output_count) || gather_outputs[static_cast(index)]) { + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count) || gather_outputs[static_cast(index)]) { can_fuse = false; break; } - Node& gather_node = *graph.GetNode(it->Index()); + Node& gather_node = *graph.GetNode(consumer->Index()); nodes_to_fuse.emplace_back(gather_node); gather_outputs[static_cast(index)] = gather_node.MutableOutputDefs()[0]; } @@ -122,8 +146,8 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (!can_fuse) continue; ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( - node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()); + const ONNX_NAMESPACE::TensorProto_DataType element_type = + static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); split_output_type.mutable_tensor_type()->set_elem_type(element_type); for (int64_t i = 0; i < rank; ++i) { if (i == split_axis) { @@ -136,16 +160,17 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le InlinedVector split_outputs; bool add_squeeze_node = indices_n_dims == 0; if (add_squeeze_node) { - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { split_outputs.emplace_back( &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type)); } } - Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", - {node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs); + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", + {graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs); split_node.AddAttribute("axis", split_axis); - split_node.SetExecutionProviderType(node.GetExecutionProviderType()); + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas. int onnx_opset_version = -1; @@ -155,16 +180,16 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (onnx_opset_version < 13) { if (add_squeeze_node) { - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]}); squeeze_node.AddAttribute("axes", std::vector{split_axis}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); } } } else { if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(output_count)); + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); } if (add_squeeze_node) { @@ -176,11 +201,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t)); NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); - for (size_t i = 0; i < output_count; ++i) { + for (size_t i = 0; i < consumer_count; ++i) { Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]}); - squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); } } } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 5015e48fdb7b8..3880288bdba2e 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -443,7 +443,6 @@ bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& gr } int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - int32_t dt_scale = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_bias = 0; bool has_bias = false; // bias is optional for LayerNorm @@ -453,9 +452,9 @@ bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& gr } int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - // Input, output, and scale need to be the same type. The bias is int32. + // Input, output, need to be the same type. The bias is int32. + // Scale can be different with input for a16w8 case return (dt_input == dt_output) && - (dt_input == dt_scale) && (has_bias ? dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32 : true); } diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 60e0b1c061a43..4a6743e9e5c52 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -169,6 +170,60 @@ Status CreateInputFeatureProvider(const std::unordered_map mlmultiarray_buffer_size) { + if (mlmultiarray_buffer == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "mlmultiarray_buffer has no data"); + } + + const size_t num_elements = array_info.count; + const auto onnx_data_type = tensor_info->data_type; + switch (onnx_data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + const auto output_data_byte_size = num_elements * sizeof(float); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, + "CoreML output buffer size and expected output size differ"); + memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + const auto output_data_byte_size = num_elements * sizeof(int32_t); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, + "CoreML output buffer size and expected output size differ"); + memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + break; + } + // For this case, since Coreml Spec only uses int32 for model output while onnx provides + // int64 for model output data type. We are doing a type casting (int32 -> int64) here + // when copying the model to ORT + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + ORT_RETURN_IF_NOT(array_info.dataType == MLMultiArrayDataTypeInt32, + "CoreML output data type is not MLMultiArrayDataTypeInt32"); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == num_elements * sizeof(int32_t), + "CoreML output buffer size and expected output size differ"); + const auto model_output_span = gsl::span{static_cast(mlmultiarray_buffer), num_elements}; + const auto output_span = gsl::span{static_cast(tensor_buffer), num_elements}; + std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), + [](int32_t v) { return static_cast(v); }); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Output data type is not supported, actual type: ", onnx_data_type); + } + return Status::OK(); +} } // namespace NS_ASSUME_NONNULL_BEGIN @@ -298,9 +353,9 @@ - (Status)predict:(const std::unordered_map&)inputs return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); } - auto* data = [output_value multiArrayValue]; + MLMultiArray* data = [output_value multiArrayValue]; - const auto coreml_static_output_shape = [&]() { + const auto coreml_static_output_shape = [data]() { InlinedVector result; result.reserve(data.shape.count); for (NSNumber* dim in data.shape) { @@ -324,41 +379,21 @@ - (Status)predict:(const std::unordered_map&)inputs ") do not match"); } - const void* model_output_buffer = data.dataPointer; - - if (model_output_buffer == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model_output_buffer has no data for ", output_name); - } - - const auto onnx_data_type = output_tensor_info.data_type; - switch (onnx_data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - const auto output_data_byte_size = num_elements * sizeof(float); - memcpy(output_buffer, model_output_buffer, output_data_byte_size); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - const auto output_data_byte_size = num_elements * sizeof(int32_t); - memcpy(output_buffer, model_output_buffer, output_data_byte_size); - break; - } - // For this case, since Coreml Spec only uses int32 for model output while onnx provides - // int64 for model output data type. We are doing a type casting (int32 -> int64) here - // when copying the model to ORT - case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - ORT_RETURN_IF_NOT(data.dataType == MLMultiArrayDataTypeInt32, - "CoreML output data type is not MLMultiArrayDataTypeInt32"); - - const auto model_output_span = gsl::span{static_cast(model_output_buffer), num_elements}; - const auto output_span = gsl::span{static_cast(output_buffer), num_elements}; - std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), - [](int32_t v) { return static_cast(v); }); - break; - } - default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Output data type is not supported, actual type: ", onnx_data_type); + ORT_RETURN_IF_NOT(IsArrayContiguous(data), + "Non-contiguous output MLMultiArray is not currently supported"); + __block Status copy_status; + const auto* tensor_info = &output_tensor_info; + // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions + if (@available(macOS 12.3, iOS 15.4, *)) { + [data getBytesWithHandler:^(const void* bytes, NSInteger size) { + copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, tensor_info, size); + }]; + } else { + // disable size check as old API does not return buffer length + copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, tensor_info, std::nullopt); } + if (!copy_status.IsOK()) + return copy_status; } } } diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index ce834e371fdef..3c83394fb0bf4 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -688,21 +688,23 @@ FastReduceKind OptimizeShapeForFastReduce(gsl::span input_shape, return FastReduceKind::kNone; } -void ValidateCommonFastReduce(const Tensor* axes_tensor) { - ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, - "An axes tensor must be a vector tensor."); -} - // template bool CommonFastReduceCopy(OpKernelContext* ctx, TensorShapeVector& input_axes, bool noop_with_empty_axes) { if (ctx->InputCount() == 2) { // second input holds the axes. + // the argument is optional const Tensor* axes_tensor = ctx->Input(1); - ValidateCommonFastReduce(axes_tensor); - auto nDims = static_cast(axes_tensor->Shape()[0]); - const auto* data = axes_tensor->Data(); - input_axes.insert(input_axes.begin(), data, data + nDims); + + if (axes_tensor != nullptr) { + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + + const auto data_span = axes_tensor->DataAsSpan(); + input_axes.assign(data_span.begin(), data_span.end()); + } else { + input_axes.clear(); + } + if (input_axes.empty() && noop_with_empty_axes) { const Tensor* input = ctx->Input(0); auto* output = ctx->Output(0, input->Shape()); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 4ae59951c5e98..fdc5317419c5b 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -22,6 +22,11 @@ class SimpleOpBuilder : public BaseOpBuilder { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SimpleOpBuilder); protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -48,6 +53,90 @@ class SimpleOpBuilder : public BaseOpBuilder { static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; }; +// Move to qnn_utils if it's re-usable +Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper, + const std::string& convert_input_name, + const std::string& convert_output_name, + Qnn_DataType_t input_qnn_data_type, + Qnn_DataType_t output_qnn_data_type, + int32_t input_offset, + float input_scale, + const std::vector& output_shape, + bool do_op_validation) { + // Assume input is already handled. + float qmin = 0.0f; + float qmax = 255.0f; + ORT_RETURN_IF_ERROR(qnn::utils::GetQminQmax(input_qnn_data_type, qmin, qmax)); + double value_min = qnn::utils::Dequantize(input_offset, input_scale, qmin); + double value_max = qnn::utils::Dequantize(input_offset, input_scale, qmax); + + Qnn_QuantizeParams_t convert_output_quant_param = QNN_QUANTIZE_PARAMS_INIT; + convert_output_quant_param.encodingDefinition = QNN_DEFINITION_DEFINED; + convert_output_quant_param.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + ORT_RETURN_IF_ERROR(qnn::utils::GetQuantParams(static_cast(value_min), + static_cast(value_max), + output_qnn_data_type, + convert_output_quant_param.scaleOffsetEncoding.scale, + convert_output_quant_param.scaleOffsetEncoding.offset)); + + std::vector output_shape_copy = output_shape; + QnnTensorWrapper convert_output_tensorwrapper(convert_output_name, + QNN_TENSOR_TYPE_NATIVE, + output_qnn_data_type, + convert_output_quant_param, + std::move(output_shape_copy)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(convert_output_tensorwrapper)), "Failed to add tensor."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(convert_output_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + "Convert", + {convert_input_name}, + {convert_output_name}, + {}, + do_op_validation), + "Failed to add node."); + return Status::OK(); +} + +Status SimpleOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + const std::string& op_type = node_unit.OpType(); + ORT_RETURN_IF_ERROR(BaseOpBuilder::ProcessInputs(qnn_model_wrapper, node_unit, logger, input_names, do_op_validation)); + + if (op_type == "MatMul") { + const auto& inputs = node_unit.Inputs(); + TensorInfo input0_info = {}; + TensorInfo input1_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input0_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input1_info)); + // Need to insert Convert op if both inputs are dynamic inputs and are ufixed_16 + if (!input0_info.is_initializer && !input1_info.is_initializer && + input0_info.qnn_data_type == input1_info.qnn_data_type && + input0_info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + // insert Convert op after input1 + std::string convert_input_name = input_names.back(); + input_names.pop_back(); + const std::string& matmul_output_name = node_unit.Outputs()[0].node_arg.Name(); + std::string convert_output_name = convert_input_name + "_convert_" + matmul_output_name; + ORT_RETURN_IF_ERROR(InsertConvertOp(qnn_model_wrapper, + convert_input_name, + convert_output_name, + input1_info.qnn_data_type, + QNN_DATATYPE_UFIXED_POINT_8, + input1_info.quant_param.scaleOffsetEncoding.offset, + input1_info.quant_param.scaleOffsetEncoding.scale, + input1_info.shape, + do_op_validation)); + input_names.push_back(convert_output_name); + } + } + + return Status::OK(); +} + Status SimpleOpBuilder::ExplicitOpCheck(const NodeUnit& node_unit) const { const std::string& op_type = node_unit.OpType(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 020af451cdcd5..79f84864a5788 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1194,6 +1194,11 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { } } + if (external_stream_) { + ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_))); + ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_))); + } + if (!external_stream_ && stream_) { ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_))); } @@ -1272,6 +1277,20 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { return Status::OK(); } +// Get the pointer to the IBuilder instance. +// Note: This function is not thread safe. Calls to this function from different threads must be serialized +// even though it doesn't make sense to have multiple threads initializing the same inference session. +nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { + if (!builder_) { + TensorrtLogger& trt_logger = GetTensorrtLogger(); + { + auto lock = GetApiLock(); + builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + } + } + return builder_.get(); +} + void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { if (info_.custom_op_domain_list.empty()) { common::Status status = CreateTensorRTCustomOpDomainList(info_); @@ -1633,7 +1652,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect // Get supported node list recursively SubGraphCollection_t parser_nodes_list; TensorrtLogger& trt_logger = GetTensorrtLogger(); - auto trt_builder = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + auto trt_builder = GetBuilder(); const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(explicitBatch)); @@ -1810,6 +1829,10 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, if (sub_graphs.size() != 0) { bool all_subgraphs_are_supported = true; for (auto sub_graph : sub_graphs) { + // TRT EP should consider the empty subgraph is fully supported by TRT. + if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { + continue; + } if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { all_subgraphs_are_supported = false; break; @@ -1877,27 +1900,33 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, auto sub_graphs = graph.ParentNode()->GetSubgraphs(); for (auto sub_graph : sub_graphs) { if (sub_graph.get() != &graph.GetGraph()) { - auto sub_graph_veiwer = sub_graph->CreateGraphViewer(); - const int number_of_ort_subgraph_nodes = sub_graph_veiwer->NumberOfNodes(); + auto sub_graph_viewer = sub_graph->CreateGraphViewer(); + const int number_of_ort_subgraph_nodes = sub_graph_viewer->NumberOfNodes(); std::vector subgraph_nodes_vector(number_of_ort_subgraph_nodes); std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; bool subgraph_early_termination = false; - // Another subgraph of "If" control flow has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. - if (AllNodesAssignedToSpecificEP(*sub_graph_veiwer, kTensorrtExecutionProvider)) { + // Another subgraph of "If" control flow op has no nodes. + // In this case, TRT EP should consider this empty subgraph is fully supported by TRT. + if (sub_graph_viewer->NumberOfNodes() == 0) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. + else if (AllNodesAssignedToSpecificEP(*sub_graph_viewer, kTensorrtExecutionProvider)) { all_subgraphs_are_supported = true; break; } // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP. // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs) - else if (!AllNodesAssignedToSpecificEP(*sub_graph_veiwer, "")) { + else if (!AllNodesAssignedToSpecificEP(*sub_graph_viewer, "")) { all_subgraphs_are_supported = false; break; } // Another subgraph of "If" control flow has not yet been parsed by GetCapability. - subgraph_supported_nodes_vector = GetSupportedList(parser_subgraph_nodes_vector, 0, max_partition_iterations_, *sub_graph_veiwer, &subgraph_early_termination); + subgraph_supported_nodes_vector = GetSupportedList(parser_subgraph_nodes_vector, 0, max_partition_iterations_, *sub_graph_viewer, &subgraph_early_termination); all_subgraphs_are_supported = IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes); break; } @@ -1985,7 +2014,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(nvinfer1::createInferBuilder(trt_logger)); + auto trt_builder = GetBuilder(); const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(explicitBatch)); auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); @@ -2438,7 +2467,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, - &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], + *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), + &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, @@ -2490,7 +2518,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsync_stream_after_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; - auto trt_builder = trt_state->builder->get(); + auto trt_builder = trt_state->builder; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index cda08715ea009..a945d219088aa 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -105,10 +105,10 @@ struct TensorrtFuncState { DestroyFunc test_release_func = nullptr; AllocatorHandle allocator = nullptr; std::string fused_node_name; + nvinfer1::IBuilder* builder; tensorrt_ptr::unique_pointer* parser = nullptr; std::unique_ptr* engine = nullptr; std::unique_ptr* context = nullptr; - std::unique_ptr* builder = nullptr; std::unique_ptr* network = nullptr; std::vector> input_info; std::vector> output_info; @@ -245,6 +245,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; mutable std::unordered_map> subgraph_context_map_; + mutable std::unique_ptr builder_; + // Following maps that hold TRT objects will be accessible by different threads if ORT is using multithreading. // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client. // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading @@ -456,5 +458,11 @@ class TensorrtExecutionProvider : public IExecutionProvider { void CaptureBegin(); void CaptureEnd(); void IncrementRegularRunCountBeforeGraphCapture(); + + /** + * Get the pointer to the IBuilder instance. + * This function only creates the instance at the first time it's being called." + */ + nvinfer1::IBuilder* GetBuilder() const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 46c456556e016..8ae16f0dd21fc 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -156,6 +156,7 @@ static const InlinedHashMap op_map = { {"GlobalMaxPool", "maxPool2d"}, {"GlobalLpPool", "l2Pool2d"}, {"Greater", "greater"}, + {"GreaterOrEqual", "greaterOrEqual"}, {"GroupNormalization", "meanVarianceNormalization"}, {"HardSigmoid", "hardSigmoid"}, {"HardSwish", "hardSwish"}, @@ -164,6 +165,7 @@ static const InlinedHashMap op_map = { {"LayerNormalization", "meanVarianceNormalization"}, {"LeakyRelu", "leakyRelu"}, {"Less", "lesser"}, + {"LessOrEqual", "lesserOrEqual"}, {"Log", "log"}, {"LpPool", "l2Pool2d"}, {"MatMul", "matmul"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 4cb49d8f8cd3a..c8f58fa98635f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -35,8 +35,12 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons output = model_builder.GetBuilder().call("equal", input0, input1); } else if (op_type == "Greater") { output = model_builder.GetBuilder().call("greater", input0, input1); + } else if (op_type == "GreaterOrEqual") { + output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1); } else if (op_type == "Less") { output = model_builder.GetBuilder().call("lesser", input0, input1); + } else if (op_type == "LessOrEqual") { + output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -54,7 +58,9 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& { "Equal", "Greater", + "GreaterOrEqual", "Less", + "LessOrEqual", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 8778bb2414108..e48cf35012652 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -114,6 +114,22 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, if (!GetShape(*input_defs[0], input_shape, logger)) { return false; } + + if (input_defs.size() < 3) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data, starts, ends) but got " + << input_defs.size(); + return false; + } + + // Inputs: starts, ends, axes, and steps must be constant initializers if present. + for (size_t i = 1; i < input_defs.size(); i++) { + if (!Contains(initializers, input_defs[i]->Name())) { + LOGS(logger, VERBOSE) << "Input [" << input_defs[i]->Name() << "] of " << op_type + << " [" << name << "] must be known as initializer"; + return false; + } + } + if (input_defs.size() == 5) { // Check steps. const auto& steps_tensor = *initializers.at(input_defs[4]->Name()); std::vector unpacked_tensor; @@ -140,18 +156,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } - if (input_defs.size() < 3) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data starts and ends) but got " - << input_defs.size(); - return false; - } - - const auto& starts_name = input_defs[1]->Name(); - const auto& ends_name = input_defs[2]->Name(); - if (!Contains(initializers, starts_name) || !Contains(initializers, ends_name)) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] need starts and ends as initializer."; - return false; - } return true; } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 65dc8ddbeaf90..463317a4dafda 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -99,7 +99,9 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { { // Logical CreateLogicalOpBuilder("Equal", op_registrations); CreateLogicalOpBuilder("Greater", op_registrations); + CreateLogicalOpBuilder("GreaterOrEqual", op_registrations); CreateLogicalOpBuilder("Less", op_registrations); + CreateLogicalOpBuilder("LessOrEqual", op_registrations); } { // Max/Min diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index ccedc71b9119a..f02d180ab104f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2025,9 +2025,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spansecond.tensor_shape.has_value()) { + const auto& opt_shape = iter->second.tensor_shape; + if (opt_shape.has_value() && !opt_shape->GetDims().empty()) { ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, input_output_tensor.Shape(), - *iter->second.tensor_shape, input_output_moniker)); + *opt_shape, input_output_moniker)); } } else if (input_output_ml_value.IsSparseTensor()) { #if !defined(DISABLE_SPARSE_TENSORS) @@ -2038,9 +2039,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spansecond.tensor_shape.has_value()) { + const auto& opt_shape = iter->second.tensor_shape; + if (opt_shape.has_value() && !opt_shape->GetDims().empty()) { ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(), - *iter->second.tensor_shape, input_output_moniker)); + *opt_shape, input_output_moniker)); } } else if (is_sparse_initializer(name) && expected_type->IsTensorType()) { @@ -2049,9 +2051,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spansecond.tensor_shape.has_value()) { + const auto& opt_shape = iter->second.tensor_shape; + if (opt_shape.has_value() && !opt_shape->GetDims().empty()) { ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(), - *iter->second.tensor_shape, input_output_moniker)); + *opt_shape, input_output_moniker)); } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name: '", name, @@ -2061,7 +2064,6 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spanIsTensorSequenceType() #if !defined(DISABLE_OPTIONAL_TYPE) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index a91ff91010e4b..a9cbef98d9165 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -157,6 +157,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "MemcpyFromHost": self._pass_on_shape_and_type, "MemcpyToHost": self._pass_on_shape_and_type, "Min": self._infer_symbolic_compute_ops, + "MoE": self._pass_on_shape_and_type, "Mul": self._infer_symbolic_compute_ops, "NonMaxSuppression": self._infer_NonMaxSuppression, "NonZero": self._infer_NonZero, diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index c1b241aa1a5ec..d11cb91d98b0c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -657,7 +657,6 @@ def create_multihead_attention_node( return None graph_input_names = set([node.name for node in self.model.graph().input]) - graph_output_names = set([node.name for node in self.model.graph().output]) mha_node_name = self.model.create_node_name("Attention") # Add initial Q/K/V inputs for MHA @@ -693,12 +692,15 @@ def create_multihead_attention_node( mha_inputs.append("") # Add optional inputs for MHA - if past_k and past_v and past_k in graph_input_names and past_v in graph_input_names: + + if past_k and past_v: mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v]) + elif key_padding_mask or add_qk: + mha_inputs.extend([key_padding_mask, add_qk]) # Add outputs for MHA mha_outputs = [output] - if present_k and present_v and present_k in graph_output_names and present_v in graph_output_names: + if present_k and present_v: mha_outputs.extend([present_k, present_v]) mha_node = helper.make_node( diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py new file mode 100644 index 0000000000000..6bc681c57444e --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -0,0 +1,143 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +from fusion_attention import AttentionMask, FusionAttention +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionConformerAttention(FusionAttention): + """ + Fuse Conformer Attention subgraph into one MultiHeadAttention node. + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ) + if qkv_nodes is not None: + ( + _, + _, + reshape_qkv, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes + else: + logger.debug("fuse_conformer_attention: failed to match qkv path") + return + + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 1, 0, 0, 1], + ) + + add_v = None + if v_nodes is not None: + (concat_v, _, _, add_v, matmul_v) = v_nodes + concat_parent = self.model.get_parent(concat_v, 0, None) + present_v = concat_v.output[0] + past_v = concat_parent.output[0] + else: + logger.debug("fuse_conformer_attention: failed to match v path") + return + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + + if qk_nodes is not None: + _, add_qk, matmul_qk = qk_nodes + else: + logger.debug("fuse_conformer_attention: failed to match qk path") + return + + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Div", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], + ) + if q_nodes is not None: + _, _, reshape_q, add_q, matmul_q = q_nodes + else: + logger.debug("fuse_conformer_attention: failed to match q path") + return + + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 1, 0, 0, 1], + ) + + matmul_k = None + if k_nodes is not None: + _, concat_k, _, _, add_k, matmul_k = k_nodes + concat_parent = self.model.get_parent(concat_k, 0, None) + past_k = concat_parent.output[0] + present_k = concat_k.output[0] + else: + logger.debug("fuse_conformer_attention: failed to match k path") + return + + attention_last_node = reshape_qkv + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) + + if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: + logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size") + return + + new_node = self.create_multihead_attention_node( + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + num_heads, + hidden_size, + attention_last_node.output[0], + add_qk=add_qk.input[1], + past_k=past_k, + past_v=past_v, + present_k=present_k, + present_v=present_v, + ) + + if new_node is None: + logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed") + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + + # When using multihead attention, keep MatMul nodes in original graph + if q_nodes[-1].op_type == "MatMul": + q_nodes.pop() + if k_nodes[-1].op_type == "MatMul": + k_nodes.pop() + if v_nodes[-1].op_type == "MatMul": + v_nodes.pop() + + self.nodes_to_remove.extend(k_nodes) + self.nodes_to_remove.extend(v_nodes) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 0955fcde6a672..1ec1ca3ba0c83 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -22,12 +22,6 @@ These optimizations are firstly carried out on CUDA EP. They may not work on oth | [optimize_pipeline.py](./optimize_pipeline.py) | Optimize Stable Diffusion ONNX models exported from Huggingface diffusers or optimum | | [benchmark.py](./benchmark.py) | Benchmark latency and memory of OnnxRuntime, xFormers or PyTorch 2.0 on stable diffusion. | -In some example, we run the scripts in source code directory. You can get source code like the following: - -``` -git clone https://github.com/microsoft/onnxruntime -cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion -``` ## Run demo with docker @@ -36,6 +30,7 @@ cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion git clone https://github.com/microsoft/onnxruntime cd onnxruntime ``` + #### Launch NVIDIA pytorch container Install nvidia-docker using [these instructions](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker). @@ -44,12 +39,6 @@ Install nvidia-docker using [these instructions](https://docs.nvidia.com/datacen docker run --rm -it --gpus all -v $PWD:/workspace nvcr.io/nvidia/pytorch:23.10-py3 /bin/bash ``` -Optionally, you can update TensorRT from 8.6.1 to latest pre-release. -``` -python3 -m pip install --upgrade pip -python3 -m pip install --pre --upgrade --extra-index-url https://pypi.nvidia.com tensorrt -``` - #### Build onnxruntime from source After launching the docker, you can build and install onnxruntime-gpu wheel like the following. ``` @@ -61,6 +50,7 @@ sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_ve --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF \ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ --allow_running_as_root +python3 -m pip install --upgrade pip python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall ``` @@ -83,7 +73,7 @@ python3 demo_txt2img_xl.py --help For example: `--engine {ORT_CUDA,ORT_TRT,TRT}` can be used to choose different backend engines including CUDA or TensorRT execution provider of ONNX Runtime, or TensorRT. -`--work-dir WORK_DIR` can be used to save models under a specified directory. +`--work-dir WORK_DIR` can be used to load or save models under the given directory. You can download the [optimized ONNX models of Stable Diffusion XL 1.0](https://huggingface.co/tlwu/stable-diffusion-xl-1.0-onnxruntime#usage-example) to save time in running the XL demo. #### Generate an image guided by a text prompt ```python3 demo_txt2img.py "astronaut riding a horse on mars"``` @@ -93,11 +83,12 @@ For example: If you do not provide prompt, the script will generate different image sizes for a list of prompts for demonstration. -It is recommended to use a machine with 64 GB or more memory to run this demo. +## Optimize Stable Diffusion ONNX models for Hugging Face Diffusers or Optimum -## Example of Stable Diffusion 1.5 or XL +If you are able to run the above demo with docker, you can use the docker and skip the following setup and fast forward to [Export ONNX pipeline](#export-onnx-pipeline). -Below is example to optimize Stable Diffusion 1.5 or XL in Linux. For Windows OS, please change the format of path to be like `.\sd` instead of `./sd`. +Below setup does not use docker. We'll use the environment to optimize ONNX models of Stable Diffusion exported by huggingface diffusers or optimum. +For Windows OS, please change the format of path to be like `.\sd` instead of `./sd`. It is recommended to create a Conda environment with Python 3.10 for the following setup: ``` @@ -217,7 +208,13 @@ Example to optimize the exported float32 ONNX models, and save to float16 models python -m onnxruntime.transformers.models.stable_diffusion.optimize_pipeline -i ./sd_v1_5/fp32 -o ./sd_v1_5/fp16 --float16 ``` -For SDXL model, it is recommended to use a machine with 32 GB or more memory to optimize. +In all examples below, we run the scripts in source code directory. You can get source code like the following: +``` +git clone https://github.com/microsoft/onnxruntime +cd onnxruntime/onnxruntime/python/tools/transformers/models/stable_diffusion +``` + +For SDXL model, it is recommended to use a machine with 48 GB or more memory to optimize. ``` python optimize_pipeline.py -i ./sd_xl_base_onnx -o ./sd_xl_base_fp16 --float16 ``` diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index 57f7986503a5c..4636f139d4613 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -53,10 +53,31 @@ f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" ) - min_image_size = 512 - max_image_size = 1024 if args.version in ["2.0", "2.1"] else 768 + # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size. + # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance. + # This range can cover common used shape of landscape 512x768, portrait 768x512, or square 512x512 and 768x768. + min_image_size = 512 if args.engine != "ORT_CUDA" else 256 + max_image_size = 768 if args.engine != "ORT_CUDA" else 1024 pipeline_info = PipelineInfo(args.version, min_image_size=min_image_size, max_image_size=max_image_size) - pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size) + + # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to + # optimize the shape used most frequently. We can let user config it when we develop a UI plugin. + # In this demo, we optimize batch size 1 and image size 512x512 (or 768x768 for SD 2.0/2.1) for dynamic engine. + # This is mainly for benchmark purpose to simulate the case that we have no knowledge of user's preference. + opt_batch_size = 1 if args.build_dynamic_batch else batch_size + opt_image_height = pipeline_info.default_image_size() if args.build_dynamic_shape else args.height + opt_image_width = pipeline_info.default_image_size() if args.build_dynamic_shape else args.width + + pipeline = init_pipeline( + Txt2ImgPipeline, + pipeline_info, + engine_type, + args, + max_batch_size, + opt_batch_size, + opt_image_height, + opt_image_width, + ) if engine_type == EngineType.TRT: max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index 364af9e8e9923..4f9ecf6cbb152 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -48,23 +48,56 @@ def load_pipelines(args, batch_size): # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size. # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance. + # This range can cover most frequent shape of landscape (832x1216), portrait (1216x832) or square (1024x1024). min_image_size = 832 if args.engine != "ORT_CUDA" else 512 max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048 # No VAE decoder in base when it outputs latent instead of image. - base_info = PipelineInfo(args.version, use_vae=False, min_image_size=min_image_size, max_image_size=max_image_size) - base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size) + base_info = PipelineInfo( + args.version, use_vae=args.disable_refiner, min_image_size=min_image_size, max_image_size=max_image_size + ) - refiner_info = PipelineInfo( - args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size + # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to + # optimize the shape used most frequently. We can let user config it when we develop a UI plugin. + # In this demo, we optimize batch size 1 and image size 1024x1024 for SD XL dynamic engine. + # This is mainly for benchmark purpose to simulate the case that we have no knowledge of user's preference. + opt_batch_size = 1 if args.build_dynamic_batch else batch_size + opt_image_height = base_info.default_image_size() if args.build_dynamic_shape else args.height + opt_image_width = base_info.default_image_size() if args.build_dynamic_shape else args.width + + base = init_pipeline( + Txt2ImgXLPipeline, + base_info, + engine_type, + args, + max_batch_size, + opt_batch_size, + opt_image_height, + opt_image_width, ) - refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size) + + refiner = None + if not args.disable_refiner: + refiner_info = PipelineInfo( + args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size + ) + refiner = init_pipeline( + Img2ImgXLPipeline, + refiner_info, + engine_type, + args, + max_batch_size, + opt_batch_size, + opt_image_height, + opt_image_width, + ) if engine_type == EngineType.TRT: - max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) + max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory()) _, shared_device_memory = cudart.cudaMalloc(max_device_memory) base.backend.activate_engines(shared_device_memory) - refiner.backend.activate_engines(shared_device_memory) + if refiner: + refiner.backend.activate_engines(shared_device_memory) if engine_type == EngineType.ORT_CUDA: enable_vae_slicing = args.enable_vae_slicing @@ -72,7 +105,7 @@ def load_pipelines(args, batch_size): print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.") enable_vae_slicing = True if enable_vae_slicing: - refiner.backend.enable_vae_slicing() + (refiner or base).backend.enable_vae_slicing() return base, refiner @@ -81,7 +114,8 @@ def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False image_width = args.width batch_size = len(prompt) base.load_resources(image_height, image_width, batch_size) - refiner.load_resources(image_height, image_width, batch_size) + if refiner: + refiner.load_resources(image_height, image_width, batch_size) def run_base_and_refiner(warmup=False): images, time_base = base.run( @@ -93,8 +127,13 @@ def run_base_and_refiner(warmup=False): denoising_steps=args.denoising_steps, guidance=args.guidance, seed=args.seed, - return_type="latent", + return_type="latent" if refiner else "image", ) + if refiner is None: + return images, time_base + + # Use same seed in base and refiner. + seed = base.get_current_seed() images, time_refiner = refiner.run( prompt, @@ -105,7 +144,7 @@ def run_base_and_refiner(warmup=False): warmup=warmup, denoising_steps=args.denoising_steps, guidance=args.guidance, - seed=args.seed, + seed=seed, ) return images, time_base + time_refiner @@ -142,7 +181,8 @@ def run_demo(args): base, refiner = load_pipelines(args, batch_size) run_pipelines(args, base, refiner, prompt, negative_prompt) base.teardown() - refiner.teardown() + if refiner: + refiner.teardown() def run_dynamic_shape_demo(args): @@ -160,20 +200,20 @@ def run_dynamic_shape_demo(args): "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic", ] - # batch size, height, width, scheduler, steps, prompt + # batch size, height, width, scheduler, steps, prompt, seed configs = [ - (1, 832, 1216, "UniPC", 8, prompts[0]), - (1, 1024, 1024, "DDIM", 24, prompts[1]), - (1, 1216, 832, "EulerA", 18, prompts[2]), - (2, 1344, 768, "DDIM", 30, prompts[3]), - (2, 640, 1536, "UniPC", 18, prompts[4]), - (2, 1152, 896, "EulerA", 30, prompts[5]), + (1, 832, 1216, "UniPC", 8, prompts[0], None), + (1, 1024, 1024, "DDIM", 24, prompts[1], None), + (1, 1216, 832, "UniPC", 16, prompts[2], None), + (1, 1344, 768, "DDIM", 24, prompts[3], None), + (2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712), + (2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906), ] - # Warm up (for cudnn convolution algo search) once before serving. + # Warm up each combination of (batch size, height, width) once before serving. args.prompt = ["warm up"] args.num_warmup_runs = 1 - for batch_size, height, width, _, _, _ in configs: + for batch_size, height, width, _, _, _, _ in configs: args.batch_size = batch_size args.height = height args.width = width @@ -183,23 +223,26 @@ def run_dynamic_shape_demo(args): # Run pipeline on a list of prompts. args.num_warmup_runs = 0 - for batch_size, height, width, scheduler, steps, example_prompt in configs: + for batch_size, height, width, scheduler, steps, example_prompt, seed in configs: args.prompt = [example_prompt] args.batch_size = batch_size args.height = height args.width = width args.scheduler = scheduler args.denoising_steps = steps + args.seed = seed base.set_scheduler(scheduler) - refiner.set_scheduler(scheduler) + if refiner: + refiner.set_scheduler(scheduler) print( - f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}" + f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}" ) prompt, negative_prompt = repeat_prompt(args) run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False) base.teardown() - refiner.teardown() + if refiner: + refiner.teardown() if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index f28e5963b46d1..39ee273a3130d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -68,7 +68,7 @@ def parse_arguments(is_xl: bool, description: str): "--scheduler", type=str, default="DDIM", - choices=["DDIM", "EulerA", "UniPC"], + choices=["DDIM", "UniPC"] if is_xl else ["DDIM", "EulerA", "UniPC"], help="Scheduler for diffusion process", ) @@ -145,6 +145,10 @@ def parse_arguments(is_xl: bool, description: str): parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + parser.add_argument( + "--disable-refiner", action="store_true", help="Disable refiner and only run base for XL pipeline." + ) + group = parser.add_argument_group("Options for ORT_CUDA engine only") group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.") @@ -174,9 +178,9 @@ def parse_arguments(is_xl: bool, description: str): ) # Validate image dimensions - if args.height % 8 != 0 or args.width % 8 != 0: + if args.height % 64 != 0 or args.width % 64 != 0: raise ValueError( - f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}." + f"Image height and width have to be divisible by 64 but specified as: {args.height} and {args.width}." ) if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph: @@ -209,7 +213,9 @@ def repeat_prompt(args): return prompt, negative_prompt -def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size): +def init_pipeline( + pipeline_class, pipeline_info, engine_type, args, max_batch_size, opt_batch_size, opt_image_height, opt_image_width +): onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( work_dir=args.work_dir, pipeline_info=pipeline_info, engine_type=engine_type ) @@ -234,9 +240,6 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si engine_dir=engine_dir, framework_model_dir=framework_model_dir, onnx_dir=onnx_dir, - opt_image_height=args.height, - opt_image_width=args.height, - opt_batch_size=batch_size, force_engine_rebuild=args.force_engine_build, device_id=torch.cuda.current_device(), ) @@ -247,9 +250,9 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si framework_model_dir, onnx_dir, args.onnx_opset, - opt_image_height=args.height, - opt_image_width=args.height, - opt_batch_size=batch_size, + opt_image_height=opt_image_height, + opt_image_width=opt_image_width, + opt_batch_size=opt_batch_size, force_engine_rebuild=args.force_engine_build, static_batch=not args.build_dynamic_batch, static_image_shape=not args.build_dynamic_shape, @@ -264,9 +267,9 @@ def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_si framework_model_dir, onnx_dir, args.onnx_opset, - opt_batch_size=batch_size, - opt_image_height=args.height, - opt_image_width=args.height, + opt_batch_size=opt_batch_size, + opt_image_height=opt_image_height, + opt_image_width=opt_image_width, force_export=args.force_onnx_export, force_optimize=args.force_onnx_optimize, force_build=args.force_engine_build, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py index 81e5a0417c001..514205d3b8945 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -211,6 +211,13 @@ def min_image_size(self): def max_image_size(self): return self._max_image_size + def default_image_size(self): + if self.is_xl(): + return 1024 + if self.version in ("2.0", "2.1"): + return 768 + return 512 + class BaseModel: def __init__( diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py index 07c675b2ed990..a03ca7ce2912c 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_cuda.py @@ -159,9 +159,6 @@ def build_engines( framework_model_dir: str, onnx_dir: str, onnx_opset_version: int = 17, - opt_image_height: int = 512, - opt_image_width: int = 512, - opt_batch_size: int = 1, force_engine_rebuild: bool = False, device_id: int = 0, save_fp32_intermediate_model=False, @@ -209,7 +206,8 @@ def build_engines( with torch.inference_mode(): # For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern. - inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + # Export model with sample of batch size 1, image size 512 x 512 + inputs = model_obj.get_sample_input(1, 512, 512) torch.onnx.export( model, diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py index 51b0966194189..e675c9a7b3bf5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -109,7 +109,7 @@ def __init__( self.tokenizer = None self.tokenizer2 = None - self.generator = None + self.generator = torch.Generator(device="cuda") self.actual_steps = None self.current_scheduler = None @@ -181,8 +181,13 @@ def load_resources(self, image_height, image_width, batch_size): self.backend.load_resources(image_height, image_width, batch_size) def set_random_seed(self, seed): - # Initialize noise generator. Usually, it is done before a batch of inference. - self.generator = torch.Generator(device="cuda").manual_seed(seed) if isinstance(seed, int) else None + if isinstance(seed, int): + self.generator.manual_seed(seed) + else: + self.generator.seed() + + def get_current_seed(self): + return self.generator.initial_seed() def teardown(self): for e in self.events.values(): @@ -452,8 +457,18 @@ def save_images(self, images, pipeline, prompt): images = self.to_pil_image(images) random_session_id = str(random.randint(1000, 9999)) for i, image in enumerate(images): + seed = str(self.get_current_seed()) image_path = os.path.join( - self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + ".png" + self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + "-" + seed + ".png" ) print(f"Saving image {i+1} / {len(images)} to: {image_path}") - image.save(image_path) + + from PIL import PngImagePlugin + + metadata = PngImagePlugin.PngInfo() + metadata.add_text("prompt", prompt[i]) + metadata.add_text("batch_size", str(len(images))) + metadata.add_text("denoising_steps", str(self.denoising_steps)) + metadata.add_text("actual_steps", str(self.actual_steps)) + metadata.add_text("seed", seed) + image.save(image_path, "PNG", pnginfo=metadata) diff --git a/onnxruntime/python/tools/transformers/onnx_model_conformer.py b/onnxruntime/python/tools/transformers/onnx_model_conformer.py new file mode 100644 index 0000000000000..1506d85f53fd4 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_conformer.py @@ -0,0 +1,33 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +from typing import Optional + +from fusion_attention import AttentionMask +from fusion_conformer_attention import FusionConformerAttention +from fusion_options import FusionOptions +from onnx_model_bert import BertOnnxModel + +logger = logging.getLogger(__name__) + + +class ConformerOnnxModel(BertOnnxModel): + def __init__(self, model, num_heads, hidden_size): + super().__init__(model, num_heads, hidden_size) + self.attention_mask = AttentionMask(self) + self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention + self.attention_fusion.disable_multi_head_attention_bias = ( + False if options is None else options.disable_multi_head_attention_bias + ) + super().optimize(options, add_dynamic_axes) + + def fuse_attention(self): + self.attention_fusion.apply() + + def preprocess(self): + self.adjust_reshape_and_expand() diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 94a757320e598..6842a97fe0c77 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -32,6 +32,7 @@ from onnx_model_bert_keras import BertOnnxModelKeras from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_clip import ClipOnnxModel +from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel @@ -56,6 +57,7 @@ "unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), + "conformer": (ConformerOnnxModel, "pytorch", 1), } diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 918ee0e6eb976..3c6217915bef0 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -72,21 +72,17 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop MlasTranspose(input1_f_vals.data(), input1_f_vals_trans.data(), K, N); #endif - int meta_rows; - int meta_cols; - MlasBlockwiseQuantMetaShape((int)block_size, true, (int)K, (int)N, meta_rows, meta_cols); + int q_rows, q_cols; + MlasBlockwiseQuantizedShape((int)block_size, true, (int)K, (int)N, q_rows, q_cols); - int q_rows; - int q_cols; - MlasBlockwiseQuantizedShape((int)block_size, true, (int)K, (int)N, q_rows, q_cols); + size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes(4, static_cast(block_size), /* columnwise */ true, + static_cast(K), static_cast(N), + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); - std::vector input1_vals(q_rows * q_cols); - std::vector scales(meta_rows * meta_cols); - - // TODO!! THIS SHOULD BE PROVIDED BY MLAS - // sub 8b packing always happen on the column dimension - const int packed_meta_rows = (meta_rows * QBits + 7) / 8; - std::vector zp(packed_meta_rows * meta_cols); + std::vector input1_vals(q_data_size_in_bytes); + std::vector scales(q_scale_size); + std::vector zp(q_zp_size_in_bytes); QuantizeDequantize(input1_f_vals, input1_vals, @@ -115,9 +111,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop if (use_float16) { test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); test.AddInput("B", {q_cols, q_rows}, input1_vals, true); - test.AddInput("scales", {meta_cols * meta_rows}, ToFloat16(scales), true); + test.AddInput("scales", {static_cast(q_scale_size)}, ToFloat16(scales), true); if (has_zeropoint) { - test.AddInput("zero_points", {meta_cols * packed_meta_rows}, zp, true); + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); } test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); @@ -129,9 +125,9 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, bool has_zerop } else { test.AddInput("A", {M, K}, input0_vals, false); test.AddInput("B", {q_cols, q_rows}, input1_vals, true); - test.AddInput("scales", {meta_cols * meta_rows}, scales, true); + test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); if (has_zeropoint) { - test.AddInput("zero_points", {meta_cols * packed_meta_rows}, zp, true); + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); } test.AddOutput("Y", {M, N}, expected_vals); diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc new file mode 100644 index 0000000000000..ebb0261deefa5 --- /dev/null +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -0,0 +1,423 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void RunMoETest( + const std::vector& input, + const std::vector& router_probs, + const std::vector& fc1_experts_weights, + const std::vector& fc2_experts_weights, + const std::vector& fc1_experts_bias, + const std::vector& fc2_experts_bias, + const std::vector& output_data, + int num_rows, + int num_experts, + int hidden_size, + int inter_size, + std::string activation_type, + bool use_float16 = false) { + int min_cuda_architecture = use_float16 ? 530 : 0; + + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + if (enable_cuda) { + OpTester tester("MoE", 1, onnxruntime::kMSDomain); + tester.AddAttribute("k", static_cast(1)); + tester.AddAttribute("activation_type", activation_type); + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; + std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc1_experts_bias_dims = {num_experts, inter_size}; + std::vector fc2_experts_bias_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + if (use_float16) { + tester.AddInput("input", input_dims, ToFloat16(input)); + tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, ToFloat16(fc1_experts_weights)); + tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, ToFloat16(fc2_experts_weights)); + tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, ToFloat16(fc1_experts_bias)); + tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, ToFloat16(fc2_experts_bias)); + tester.AddOutput("output", output_dims, ToFloat16(output_data)); + } else { + tester.AddInput("input", input_dims, input); + tester.AddInput("router_probs", router_probs_dims, router_probs); + tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias); + tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias); + tester.AddOutput("output", output_dims, output_data); + } + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MoETest, MoETest_Gelu) { + int num_rows = 4; + int num_experts = 4; + int hidden_size = 8; + int inter_size = 16; + + const std::vector input = { + -1.1200173f, -0.45884353f, -1.2929888f, 1.0784022f, 0.116372705f, 0.26902613f, -1.8818876f, -0.5457026f, + 0.22222236f, -0.28868636f, 0.6692926f, 1.4944887f, 0.02431708f, -0.49781424f, 0.7378293f, 1.276276f, + -0.15469065f, -0.28456813f, -0.6296439f, -0.24855971f, 0.80565417f, -1.1018785f, -0.74082595f, 0.82407707f, + -0.95033455f, 0.659333f, -0.68629056f, -0.2916592f, 1.869919f, -1.1053563f, -0.14417848f, -0.34625578f}; + const std::vector router_probs = { + -0.84837115f, 0.100507565f, -0.10548311f, 0.40957215f, 1.0159845f, 0.26919764f, 0.021741152f, -0.34184334f, + -0.71324956f, 0.29018253f, -0.18227568f, 0.31496462f, -0.48426327f, -1.006643f, -0.100081146f, -0.07692295f}; + const std::vector fc1_experts_weights = { + 0.14731085f, 0.52229995f, 0.14753294f, 0.22475791f, 0.20864725f, 0.6708725f, 0.20204341f, 0.4890914f, + 0.52103406f, 0.8223115f, 0.122039974f, 0.15674388f, 0.20966923f, 0.8499667f, 0.3202675f, 0.92174435f, + 0.6808038f, 0.563313f, 0.496278f, 0.40115923f, 0.5627332f, 0.38582766f, 0.49648678f, 0.5637965f, + 0.10889745f, 0.23793429f, 0.90374637f, 0.09422666f, 0.4640969f, 0.99461937f, 0.6806185f, 0.5141565f, + 0.066695035f, 0.74768895f, 0.14385962f, 0.35806787f, 0.33224183f, 0.4259563f, 0.50546914f, 0.91240376f, + 0.5624194f, 0.9478464f, 0.8058562f, 0.18389302f, 0.72425205f, 0.14655197f, 0.28808743f, 0.64706135f, + 0.66509604f, 0.875114f, 0.33904207f, 0.50080043f, 0.7574118f, 0.016453922f, 0.8614903f, 0.08653879f, + 0.50689125f, 0.41499162f, 0.23666352f, 0.5660855f, 0.91345936f, 0.35384023f, 0.20315295f, 0.31508058f, + 0.0044258237f, 0.725697f, 0.25986814f, 0.16632986f, 0.21194929f, 0.787478f, 0.76478684f, 0.8837609f, + 0.68136156f, 0.33302015f, 0.36027592f, 0.647715f, 0.91101736f, 0.6359461f, 0.26342732f, 0.2649613f, + 0.02726549f, 0.608024f, 0.21940875f, 0.054212093f, 0.93843824f, 0.1752944f, 0.44311923f, 0.64324677f, + 0.51592916f, 0.16355914f, 0.09583914f, 0.8985412f, 0.58141935f, 0.91481227f, 0.3323797f, 0.6472777f, + 0.3856619f, 0.47776443f, 0.1954779f, 0.66910046f, 0.65808296f, 0.4896857f, 0.38754892f, 0.1917851f, + 0.8457724f, 0.12778795f, 0.70483273f, 0.33187324f, 0.258766f, 0.58982253f, 0.24027151f, 0.6152024f, + 0.5981904f, 0.12875527f, 0.5832493f, 0.7129646f, 0.6979155f, 0.43706065f, 0.09010619f, 0.42292297f, + 0.67365384f, 0.31756145f, 0.68979055f, 0.8329813f, 0.2389242f, 0.5049309f, 0.7067495f, 0.5391889f, + 0.54176575f, 0.5624327f, 0.10692614f, 0.5392941f, 0.8462349f, 0.9505569f, 0.79387546f, 0.5670015f, + 0.7335071f, 0.25676018f, 0.08565581f, 0.07003945f, 0.99880487f, 0.8173947f, 0.15438312f, 0.6956213f, + 0.8775838f, 0.9998074f, 0.93719745f, 0.8873769f, 0.38537037f, 0.32452917f, 0.9105244f, 0.7801898f, + 0.19911051f, 0.9495086f, 0.7415793f, 0.77256775f, 0.18661183f, 0.6434499f, 0.32471877f, 0.8906783f, + 0.4100297f, 0.69465625f, 0.5888109f, 0.7127341f, 0.33008623f, 0.7437857f, 0.15076452f, 0.6129275f, + 0.16170406f, 0.006731212f, 0.09847212f, 0.89473504f, 0.7705178f, 0.96910787f, 0.9005606f, 0.053477287f, + 0.15878445f, 0.4192087f, 0.17528385f, 0.84719825f, 0.121996105f, 0.25604928f, 0.016954303f, 0.21612722f, + 0.91123873f, 0.90938f, 0.85791886f, 0.88606364f, 0.94459325f, 0.3719685f, 0.72000104f, 0.9454652f, + 0.6654094f, 0.9998382f, 0.75933146f, 0.81082416f, 0.32500392f, 0.73991376f, 0.5574533f, 0.38059133f, + 0.21814507f, 0.21944171f, 0.11525959f, 0.83566517f, 0.8554656f, 0.44309366f, 0.210657f, 0.88645273f, + 0.81974447f, 0.537167f, 0.26393235f, 0.9595239f, 0.70447034f, 0.12042731f, 0.97854143f, 0.8796869f, + 0.31775457f, 0.78107727f, 0.21590549f, 0.42164284f, 0.9245506f, 0.52065957f, 0.14639091f, 0.33288354f, + 0.36427742f, 0.4035356f, 0.5478503f, 0.9624148f, 0.5267702f, 0.19128f, 0.52562714f, 0.7397436f, + 0.7480201f, 0.04303074f, 0.41052878f, 0.12842774f, 0.2866572f, 0.6801467f, 0.1449349f, 0.68586344f, + 0.92438906f, 0.5327942f, 0.16675615f, 0.32085752f, 0.60918206f, 0.11884099f, 0.74840516f, 0.04606521f, + 0.01935333f, 0.014169693f, 0.39856833f, 0.83621645f, 0.026760519f, 0.91559356f, 0.29998857f, 0.64644206f, + 0.52280146f, 0.049140453f, 0.9146645f, 0.7692217f, 0.99699783f, 0.7526061f, 0.1699655f, 0.9172919f, + 0.5268722f, 0.73710823f, 0.09908545f, 0.35618675f, 0.009061217f, 0.30525374f, 0.6078656f, 0.10741913f, + 0.6593821f, 0.7684034f, 0.56965464f, 0.16545832f, 0.11234015f, 0.3457417f, 0.7194791f, 0.9931982f, + 0.7875145f, 0.44369537f, 0.6753082f, 0.009468555f, 0.07294935f, 0.73330396f, 0.2167924f, 0.74054784f, + 0.14703393f, 0.25234455f, 0.08815551f, 0.76092035f, 0.44905245f, 0.88480055f, 0.8094361f, 0.7766713f, + 0.51607805f, 0.345411f, 0.39128417f, 0.5664503f, 0.74785477f, 0.14970505f, 0.91963893f, 0.44563496f, + 0.08102721f, 0.22947109f, 0.94240886f, 0.9572636f, 0.036860168f, 0.85264915f, 0.7505796f, 0.79595923f, + 0.9232646f, 0.23052484f, 0.6578879f, 0.7046166f, 0.35225332f, 0.66732657f, 0.3561433f, 0.80913067f, + 0.3612727f, 0.31360215f, 0.6258745f, 0.6773468f, 0.25571418f, 0.54419917f, 0.78976786f, 0.45025164f, + 0.65216696f, 0.3794065f, 0.6752498f, 0.1378029f, 0.2059856f, 0.24620473f, 0.95950544f, 0.36545795f, + 0.49863482f, 0.25775224f, 0.99914503f, 0.9883351f, 0.122906685f, 0.09466505f, 0.12100351f, 0.49758863f, + 0.37254804f, 0.17272717f, 0.32066393f, 0.59446543f, 0.23875463f, 0.61079127f, 0.38534206f, 0.25771832f, + 0.56869274f, 0.9111291f, 0.16196036f, 0.5232172f, 0.31561613f, 0.99065316f, 0.025618374f, 0.0206694f, + 0.9926925f, 0.18365502f, 0.5958617f, 0.45684695f, 0.3946715f, 0.3883261f, 0.8177203f, 0.5238985f, + 0.013192713f, 0.20481992f, 0.32954985f, 0.7516082f, 0.17643315f, 0.9714598f, 0.38863534f, 0.410219f, + 0.891779f, 0.75130385f, 0.92406017f, 0.7892222f, 0.34832305f, 0.1682638f, 0.46279848f, 0.9138188f, + 0.3321901f, 0.036315024f, 0.7049642f, 0.9867357f, 0.3576584f, 0.08598822f, 0.046470165f, 0.6252997f, + 0.46214014f, 0.24750638f, 0.60106593f, 0.6898794f, 0.8976595f, 0.8881911f, 0.42515814f, 0.059116423f, + 0.048188448f, 0.9668448f, 0.7210276f, 0.7179537f, 0.06738949f, 0.96300787f, 0.97367156f, 0.95143014f, + 0.07820749f, 0.3113383f, 0.1561181f, 0.9734828f, 0.28516f, 0.27172273f, 0.76195645f, 0.26870382f, + 0.25373894f, 0.45626426f, 0.45194024f, 0.11051077f, 0.91683406f, 0.27943915f, 0.67735744f, 0.9348918f, + 0.7521582f, 0.57078993f, 0.9254285f, 0.5672131f, 0.2686717f, 0.97299975f, 0.61834025f, 0.012159586f, + 0.3576542f, 0.15941626f, 0.9383765f, 0.41742706f, 0.044237554f, 0.46856833f, 0.81400645f, 0.6299002f, + 0.6581022f, 0.5464366f, 0.68640935f, 0.378174f, 0.3010999f, 0.032645762f, 0.12333155f, 0.71670127f, + 0.20394331f, 0.57173324f, 0.6595957f, 0.53540194f, 0.17582512f, 0.9781642f, 0.20925027f, 0.9112503f, + 0.10224587f, 0.37972575f, 0.7719844f, 0.29570967f, 0.9200215f, 0.15592176f, 0.080114245f, 0.27454042f, + 0.5808252f, 0.96037793f, 0.26129955f, 0.6788141f, 0.37464648f, 0.39156884f, 0.8676517f, 0.112507045f, + 0.55310667f, 0.9702046f, 0.4312939f, 0.88821906f, 0.3460216f, 0.9024811f, 0.016334832f, 0.42793816f, + 0.4121768f, 0.6620425f, 0.6961637f, 0.88390845f, 0.425507f, 0.48017246f, 0.8424056f, 0.36471343f, + 0.9383168f, 0.16709393f, 0.44589508f, 0.47314453f, 0.72310495f, 0.84183806f, 0.4207481f, 0.0857597f, + 0.7477461f, 0.6495659f, 0.70084965f, 0.19156617f, 0.8217978f, 0.9735775f, 0.5433857f, 0.032975793f, + 0.85099494f, 0.12927437f, 0.61493605f, 0.5726589f, 0.26598173f, 0.6740978f, 0.052783668f, 0.61387974f}; + const std::vector fc2_experts_weights = { + 0.18302453f, 0.44593316f, 0.5643144f, 0.9259722f, 0.26143986f, 0.82031804f, 0.4364831f, 0.2625361f, + 0.06460017f, 0.04124081f, 0.98830533f, 0.37530023f, 0.5249744f, 0.63555616f, 0.8398661f, 0.92673707f, + 0.9055086f, 0.12955844f, 0.4198916f, 0.20413119f, 0.21432412f, 0.6186035f, 0.969324f, 0.099448025f, + 0.80260223f, 0.24076664f, 0.40261286f, 0.89688545f, 0.38691485f, 0.5455279f, 0.15048373f, 0.92562044f, + 0.43536508f, 0.13430476f, 0.64640516f, 0.14449131f, 0.10324633f, 0.5304596f, 0.8964218f, 0.358508f, + 0.73533344f, 0.9296606f, 0.83163047f, 0.23771948f, 0.44519007f, 0.34265757f, 0.09793854f, 0.5002066f, + 0.87621754f, 0.9212578f, 0.54665035f, 0.6135615f, 0.28353918f, 0.8774212f, 0.29194576f, 0.1526736f, + 0.57699674f, 0.7996927f, 0.04920423f, 0.95198375f, 0.67986554f, 0.14969361f, 0.39229625f, 0.93378997f, + 0.11638266f, 0.3538614f, 0.66399014f, 0.06195748f, 0.7740991f, 0.7602738f, 0.81010276f, 0.18122643f, + 0.9980005f, 0.20361924f, 0.99917024f, 0.020154774f, 0.054515004f, 0.80709815f, 0.55225646f, 0.52884465f, + 0.22312081f, 0.29026228f, 0.35380626f, 0.012922287f, 0.52598435f, 0.58842945f, 0.4995767f, 0.66146517f, + 0.9744255f, 0.632942f, 0.3169638f, 0.29422665f, 0.18009722f, 0.15339059f, 0.41947508f, 0.4115672f, + 0.72243124f, 0.2862816f, 0.89860183f, 0.14915991f, 0.5014211f, 0.94945997f, 0.99719256f, 0.21036887f, + 0.5890645f, 0.55906135f, 0.26557416f, 0.32725257f, 0.635427f, 0.1523174f, 0.58249784f, 0.71636236f, + 0.30296493f, 0.9153206f, 0.46709478f, 0.72685635f, 0.9951532f, 0.34716582f, 0.7717041f, 0.3569854f, + 0.4269635f, 0.41526443f, 0.4968937f, 0.3111158f, 0.61719346f, 0.5188402f, 0.8169449f, 0.39879733f, + 0.5501401f, 0.31400484f, 0.08127314f, 0.7023336f, 0.56397897f, 0.29975814f, 0.33094752f, 0.63076067f, + 0.40959156f, 0.82673794f, 0.52832156f, 0.68886834f, 0.7178481f, 0.37731683f, 0.71633244f, 0.86896664f, + 0.5230092f, 0.59784645f, 0.5181678f, 0.8461837f, 0.28890234f, 0.23421508f, 0.7178768f, 0.06484294f, + 0.5080162f, 0.27005446f, 0.8300168f, 0.034480453f, 0.8031663f, 0.9946784f, 0.60117006f, 0.46668667f, + 0.9921749f, 0.28632385f, 0.45993322f, 0.28104752f, 0.43097937f, 0.60866946f, 0.5667807f, 0.40556252f, + 7.969141e-05f, 0.52560204f, 0.48518902f, 0.5752184f, 0.8831251f, 0.9860047f, 0.20335877f, 0.46882278f, + 0.2996632f, 0.03917718f, 0.13617045f, 0.96928054f, 0.79153055f, 0.76857555f, 0.7778716f, 0.102760494f, + 0.5525096f, 0.9653573f, 0.22095704f, 0.94479716f, 0.63141924f, 0.8517718f, 0.28580618f, 0.73050886f, + 0.05675614f, 0.46825224f, 0.6667756f, 0.6499472f, 0.91840404f, 0.99132854f, 0.9548785f, 0.8356961f, + 0.851531f, 0.43548512f, 0.111976564f, 0.31438643f, 0.44386774f, 0.22980672f, 0.75558543f, 0.6755136f, + 0.58067596f, 0.62078035f, 0.93922615f, 0.6821157f, 0.061530292f, 0.13705963f, 0.7203748f, 0.5681396f, + 0.7438458f, 0.0006400347f, 0.038565338f, 0.8066132f, 0.81982285f, 0.047644496f, 0.68979263f, 0.109577894f, + 0.8786539f, 0.6568952f, 0.99439347f, 0.0070040226f, 0.018661916f, 0.838051f, 0.94391155f, 0.80634f, + 0.8324149f, 0.078864336f, 0.8619068f, 0.027926445f, 0.61170083f, 0.17248261f, 0.30140227f, 0.5885344f, + 0.30341f, 0.42088854f, 0.02608782f, 0.02856338f, 0.69368154f, 0.28836077f, 0.19580519f, 0.30270886f, + 0.09121573f, 0.100299895f, 0.79918617f, 0.75412107f, 0.56660175f, 0.22687018f, 0.6663505f, 0.5224626f, + 0.1426636f, 0.6075949f, 0.95527196f, 0.008196831f, 0.0028039217f, 0.5640625f, 0.87651116f, 0.19575512f, + 0.61006856f, 0.85149264f, 0.6541582f, 0.6082054f, 0.998863f, 0.82573634f, 0.21878648f, 0.54321826f, + 0.7554362f, 0.94095474f, 0.002533555f, 0.77075267f, 0.35483408f, 0.010389388f, 0.610987f, 0.22779316f, + 0.5708561f, 0.17537653f, 0.12373549f, 0.4575745f, 0.33203715f, 0.79243237f, 0.54310906f, 0.8902793f, + 0.5937015f, 0.33921933f, 0.8386668f, 0.52732253f, 0.59384584f, 0.3391887f, 0.5017944f, 0.40386343f, + 0.45749134f, 0.110060334f, 0.49692506f, 0.084977865f, 0.3924346f, 0.7897731f, 0.15232486f, 0.16297412f, + 0.37791175f, 0.36293298f, 0.5846437f, 0.5830078f, 0.75354826f, 0.15555972f, 0.4647144f, 0.7796456f, + 0.93248576f, 0.46352726f, 0.2106899f, 0.6437313f, 0.78473866f, 0.18762505f, 0.20985329f, 0.7209991f, + 0.464967f, 0.02775067f, 0.21170747f, 0.7027664f, 0.33041215f, 0.8451145f, 0.89526993f, 0.57273495f, + 0.46046263f, 0.34128642f, 0.47471708f, 0.59101045f, 0.11807448f, 0.38050216f, 0.08409953f, 0.80687743f, + 0.18158185f, 0.9567719f, 0.3711096f, 0.21356237f, 0.74022657f, 0.57453954f, 0.846228f, 0.70873487f, + 0.018330276f, 0.8162452f, 0.40584308f, 0.27901447f, 0.81752694f, 0.86466515f, 0.060534656f, 0.45478833f, + 0.9106033f, 0.6936434f, 0.92123467f, 0.32865065f, 0.22417879f, 0.9299548f, 0.70841146f, 0.97999126f, + 0.2911517f, 0.17896658f, 0.44139355f, 0.029210031f, 0.6959876f, 0.8687942f, 0.62002844f, 0.45059657f, + 0.74790317f, 0.18262434f, 0.98912156f, 0.0028281808f, 0.021027386f, 0.38184917f, 0.90842223f, 0.5500629f, + 0.69202286f, 0.13349658f, 0.6823429f, 0.44412827f, 0.7004118f, 0.8531213f, 0.7173401f, 0.4574679f, + 0.46920043f, 0.18640989f, 0.31914896f, 0.82491904f, 0.29950172f, 0.8105199f, 0.30173403f, 0.38355058f, + 0.5106411f, 0.04116726f, 0.49500751f, 0.44960213f, 0.45508182f, 0.4000479f, 0.89418864f, 0.8689936f, + 0.16112137f, 0.7322634f, 0.10780871f, 0.07433933f, 0.652841f, 0.50734824f, 0.26674682f, 0.017748117f, + 0.30643195f, 0.66699976f, 0.03719926f, 0.014267266f, 0.56343627f, 0.13979793f, 0.061959863f, 0.3073569f, + 0.41949958f, 0.045647383f, 0.16613615f, 0.5327839f, 0.028514147f, 0.4297228f, 0.17714864f, 0.15338135f, + 0.6965155f, 0.11515516f, 0.1210829f, 0.78514075f, 0.59348315f, 0.9553564f, 0.36635226f, 0.25849247f, + 0.45372677f, 0.5025297f, 0.88132215f, 0.0019600391f, 0.46439964f, 0.7211761f, 0.22465849f, 0.2459296f, + 0.7416339f, 0.020907402f, 0.6184779f, 0.112906754f, 0.7485309f, 0.072479784f, 0.8074024f, 0.026683688f, + 0.07971662f, 0.50736845f, 0.8939942f, 0.0718022f, 0.27697015f, 0.9391413f, 0.4161513f, 0.7071423f, + 0.019000888f, 0.34275955f, 0.24608392f, 0.9215306f, 0.70751995f, 0.13516217f, 0.5806135f, 0.49425328f, + 0.29456508f, 0.21446168f, 0.3340807f, 0.89411324f, 0.14157385f, 0.14382833f, 0.34574044f, 0.50869817f, + 0.63610595f, 0.51500404f, 0.37963718f, 0.19682491f, 0.41028368f, 0.29872334f, 0.9039644f, 0.013295233f, + 0.1810705f, 0.093204916f, 0.4086216f, 0.8896367f, 0.9382696f, 0.06472236f, 0.47833657f, 0.7934831f, + 0.7203987f, 0.9095519f, 0.4861309f, 0.16405362f, 0.83076525f, 0.3285427f, 0.7588931f, 0.37678176f, + 0.71254706f, 0.949713f, 0.96492773f, 0.044967473f, 0.16925985f, 0.2932666f, 0.18114948f, 0.97975004f, + 0.4558406f, 0.16832972f, 0.27750528f, 0.2238177f, 0.7039947f, 0.06387442f, 0.033798456f, 0.007119417f}; + const std::vector fc1_experts_bias = { + 0.71526206f, 0.7472273f, 0.18946046f, 0.6239893f, 0.86909235f, 0.5726507f, 0.3942092f, 0.5369412f, + 0.44638616f, 0.7517496f, 0.16049433f, 0.75355124f, 0.7818118f, 0.19706267f, 0.9082818f, 0.9910924f, + 0.30288565f, 0.3599528f, 0.74917775f, 0.10828978f, 0.697729f, 0.61665237f, 0.81516486f, 0.0656966f, + 0.0846076f, 0.72456455f, 0.6801054f, 0.034616888f, 0.22117025f, 0.042510748f, 0.14178854f, 0.27440017f, + 0.91376925f, 0.40047455f, 0.7871756f, 0.97484046f, 0.7278661f, 0.052394807f, 0.75161135f, 0.6907173f, + 0.8875328f, 0.0067828894f, 0.807508f, 0.9092707f, 0.034817636f, 0.55231315f, 0.92683655f, 0.13634592f, + 0.66405964f, 0.7209387f, 0.63104504f, 0.9971379f, 0.9093898f, 0.9289774f, 0.4376766f, 0.9193563f, + 0.03404367f, 0.23018533f, 0.39305943f, 0.3514716f, 0.96184736f, 0.73583263f, 0.8219065f, 0.8401047f}; + const std::vector fc2_experts_bias = { + 0.12649822f, 0.4420895f, 0.5730123f, 0.63004625f, 0.7571163f, 0.3010466f, 0.3492328f, 0.91837066f, + 0.36580783f, 0.15267932f, 0.8390199f, 0.83857775f, 0.34321654f, 0.40003997f, 0.13106f, 0.08245313f, + 0.68802476f, 0.28640372f, 0.89804775f, 0.09964341f, 0.43088746f, 0.5107959f, 0.75697356f, 0.90466535f, + 0.83860224f, 0.720098f, 0.2705031f, 0.14292616f, 0.052693605f, 0.5248023f, 0.9849401f, 0.40502876f}; + const std::vector output = { + 0.2552814f, 0.17651685f, 0.0034551744f, -0.123282805f, 0.0073816925f, 0.004265253f, 0.16927283f, -0.05276826f, + 9.555821f, 7.6907287f, 10.626425f, 7.0543795f, 8.10093f, 10.3664465f, 10.925815f, 8.737018f, + 0.565234f, 0.17098689f, 0.10810414f, 0.43916586f, 0.3535297f, 0.45673048f, 0.3853893f, 0.18613164f, + 1.3354061f, 0.5049282f, 0.72775036f, 0.90331376f, 1.2945517f, 0.9123066f, 1.1995136f, 0.7708638f}; + + RunMoETest(input, + router_probs, + fc1_experts_weights, + fc2_experts_weights, + fc1_experts_bias, + fc2_experts_bias, + output, + num_rows, + num_experts, + hidden_size, + inter_size, + "gelu"); +} + +TEST(MoETest, MoETest_Relu) { + int num_rows = 4; + int num_experts = 4; + int hidden_size = 8; + int inter_size = 16; + + const std::vector input = { + 0.7670296f, -0.93721074f, -2.330477f, -0.78088343f, 0.8250065f, 1.2206652f, -0.06297584f, 1.1463639f, + 1.2215378f, -0.31372663f, -0.7234253f, -0.3627346f, 0.44249064f, 0.19418247f, -0.49998695f, -0.55005103f, + 0.023851749f, -1.5203826f, 0.52939993f, -0.39082858f, -1.9291036f, 0.034976702f, -0.48336256f, -1.226073f, + -0.33963847f, 0.0073261578f, -0.0521804f, 1.16749f, 1.7302082f, 2.0561688f, -0.2347232f, -1.3456243f}; + const std::vector router_probs = { + -0.08146476f, -0.40439552f, 1.0100367f, -0.7724162f, -0.08113786f, -0.36328858f, 0.3688482f, -0.013465762f, + -0.32420647f, -0.3815508f, 0.79585606f, 0.14430691f, -0.21869831f, 0.11483674f, -0.11992836f, 0.35216537f}; + const std::vector fc1_experts_weights = { + 0.81960344f, 0.9296998f, 0.45050132f, 0.38805157f, 0.50729614f, 0.47014588f, 0.62020564f, 0.6401168f, + 0.045871615f, 0.31548113f, 0.92106473f, 0.6947775f, 0.4751312f, 0.19854712f, 0.19409746f, 0.052116573f, + 0.3370188f, 0.6688521f, 0.8188108f, 0.73084867f, 0.058027983f, 0.19931877f, 0.42109168f, 0.98367476f, + 0.57232875f, 0.37051463f, 0.7068576f, 0.30955923f, 0.17637217f, 0.8649436f, 0.2726491f, 0.39976662f, + 0.0025978684f, 0.8346353f, 0.8788173f, 0.6822241f, 0.1513629f, 0.0065300465f, 0.093910515f, 0.8728501f, + 0.7400529f, 0.9207522f, 0.76193494f, 0.6265461f, 0.49510366f, 0.11974698f, 0.07161391f, 0.032325685f, + 0.704681f, 0.254516f, 0.3993737f, 0.21224737f, 0.40888822f, 0.14808255f, 0.17329216f, 0.6658554f, + 0.3514018f, 0.8086716f, 0.33959562f, 0.13321638f, 0.41178054f, 0.2576263f, 0.3470292f, 0.024002194f, + 0.77974546f, 0.15189773f, 0.75130886f, 0.7268921f, 0.85721636f, 0.11647397f, 0.8595984f, 0.2636242f, + 0.6855346f, 0.96955734f, 0.42948407f, 0.49613327f, 0.38488472f, 0.08250773f, 0.73995143f, 0.003641069f, + 0.81039995f, 0.87411255f, 0.9728532f, 0.38206023f, 0.08917904f, 0.61241513f, 0.77621365f, 0.0023456216f, + 0.38650817f, 0.20027226f, 0.45626813f, 0.25389326f, 0.2956162f, 0.34127057f, 0.024847984f, 0.91025376f, + 0.9191656f, 0.42156547f, 0.44305897f, 0.29594004f, 0.04846859f, 0.013427794f, 0.6858292f, 0.22547692f, + 0.17856151f, 0.4609884f, 0.33349442f, 0.3382396f, 0.5160656f, 0.3939438f, 0.3278438f, 0.26059705f, + 0.0930863f, 0.9192536f, 0.29990643f, 0.63248974f, 0.32651705f, 0.54063064f, 0.9661502f, 0.73036134f, + 0.06670016f, 0.6984514f, 0.9746214f, 0.63154167f, 0.83521235f, 0.99294376f, 0.4233855f, 0.6037772f, + 0.15248245f, 0.39696145f, 0.8702919f, 0.7563229f, 0.18360549f, 0.099057496f, 0.15831816f, 0.00656116f, + 0.114180505f, 0.3763513f, 0.8374386f, 0.5836911f, 0.11969727f, 0.09888804f, 0.74873763f, 0.12807935f, + 0.43843627f, 0.739853f, 0.26859397f, 0.44548005f, 0.45647776f, 0.38170832f, 0.24648392f, 0.054280818f, + 0.0958215f, 0.23226917f, 0.98291886f, 0.25849265f, 0.16423601f, 0.6211971f, 0.63780516f, 0.77395487f, + 0.8800602f, 0.7784371f, 0.004249513f, 0.5443443f, 0.80287653f, 0.45378727f, 0.20536041f, 0.9766699f, + 0.31298608f, 0.21532774f, 0.04922247f, 0.52233416f, 0.72156656f, 0.6106814f, 0.59887487f, 0.12080628f, + 0.03305638f, 0.5088047f, 0.95591706f, 0.7884607f, 0.20888287f, 0.43509573f, 0.13140821f, 0.2587883f, + 0.5905492f, 0.77226925f, 0.91418463f, 0.04094696f, 0.8343076f, 0.14735395f, 0.6872336f, 0.92312264f, + 0.5070212f, 0.9549045f, 0.07397425f, 0.3090204f, 0.79162645f, 0.39106607f, 0.39764988f, 0.29160416f, + 0.84465307f, 0.7452516f, 0.66022503f, 0.21901816f, 0.09412521f, 0.5540803f, 0.6481394f, 0.26914406f, + 0.36010116f, 0.83768386f, 0.53982985f, 0.52255917f, 0.37694973f, 0.04720515f, 0.029871285f, 0.26099247f, + 0.2458393f, 0.6557768f, 0.35444462f, 0.30438894f, 0.9767149f, 0.67416143f, 0.85645115f, 0.25794363f, + 0.2957666f, 0.68377024f, 0.16686243f, 0.17314798f, 0.47585016f, 0.31711966f, 0.125171f, 0.7965795f, + 0.90208143f, 0.58111167f, 0.41294336f, 0.036863506f, 0.31788063f, 0.6272928f, 0.73576546f, 0.43679124f, + 0.30232358f, 0.77861303f, 0.10180014f, 0.816009f, 0.30602258f, 0.5076527f, 0.40119207f, 0.5606195f, + 0.3489008f, 0.8635635f, 0.48700142f, 0.89029974f, 0.98074025f, 0.25640452f, 0.13524544f, 0.901151f, + 0.89180696f, 0.11822635f, 0.46134835f, 0.006936848f, 0.09070045f, 0.59657127f, 0.6330173f, 0.6059905f, + 0.36391765f, 0.96128887f, 0.571489f, 0.2049576f, 0.4716931f, 0.6200726f, 0.67509633f, 0.14645958f, + 0.6873948f, 0.24455917f, 0.08452982f, 0.22689629f, 0.9822047f, 0.9274289f, 0.9477422f, 0.7935056f, + 0.87772477f, 0.43307513f, 0.22488606f, 0.7498283f, 0.24090862f, 0.16256708f, 0.34033298f, 0.6049296f, + 0.7573983f, 0.3057955f, 0.20571685f, 0.56744653f, 0.2052834f, 0.17446929f, 0.76062596f, 0.4160077f, + 0.9568925f, 0.9863913f, 0.64955276f, 0.67207885f, 0.61514187f, 0.50783044f, 0.46363378f, 0.50687206f, + 0.6867124f, 0.9648854f, 0.37042046f, 0.2886421f, 0.37891757f, 0.25843787f, 0.58501935f, 0.8732242f, + 0.8909887f, 0.72956276f, 0.13203424f, 0.23164761f, 0.3901443f, 0.40783793f, 0.54112387f, 0.041014254f, + 0.65562236f, 0.11856395f, 0.18362767f, 0.08430874f, 0.9356598f, 0.026530087f, 0.8771834f, 0.48319155f, + 0.4418506f, 0.81273925f, 0.4537862f, 0.81357706f, 0.8615075f, 0.06589496f, 0.692392f, 0.5943895f, + 0.60750586f, 0.5729957f, 0.6367655f, 0.2594666f, 0.43602943f, 0.97506f, 0.83592474f, 0.48121578f, + 0.029734552f, 0.5219139f, 0.15951324f, 0.90659577f, 0.19645631f, 0.4638992f, 0.38902867f, 0.5889769f, + 0.9705138f, 0.5475096f, 0.789582f, 0.8881108f, 0.9036556f, 0.32732427f, 0.38817167f, 0.7409689f, + 0.36356616f, 0.734132f, 0.39076614f, 0.16087383f, 0.70352167f, 0.576659f, 0.7229242f, 0.996743f, + 0.84136647f, 0.97399056f, 0.5267614f, 0.06989372f, 0.14923638f, 0.18941313f, 0.059375823f, 0.24937624f, + 0.039716125f, 0.038692355f, 0.20122272f, 0.0070830584f, 0.19309378f, 0.69065434f, 0.9170264f, 0.3512686f, + 0.3545606f, 0.76697665f, 0.25331455f, 0.26358372f, 0.80806476f, 0.064349174f, 0.5611374f, 0.941691f, + 0.58574325f, 0.6359719f, 0.20880443f, 0.49310172f, 0.5274922f, 0.62271714f, 0.694273f, 0.9344639f, + 0.11835027f, 0.51498765f, 0.25018185f, 0.10446805f, 0.45996118f, 0.059881568f, 0.8489496f, 0.5579074f, + 0.23052096f, 0.76128954f, 0.02678603f, 0.3066004f, 0.40259063f, 0.07512486f, 0.18205583f, 0.4183907f, + 0.8793823f, 0.9828271f, 0.8181312f, 0.20143801f, 0.17288941f, 0.9363466f, 0.6768587f, 0.51328385f, + 0.56766605f, 0.098151624f, 0.33305728f, 0.98130906f, 0.3766839f, 0.47491795f, 0.08483446f, 0.22029644f, + 0.4897902f, 0.18942028f, 0.4379952f, 0.7034796f, 0.0109113455f, 0.64850605f, 0.16939592f, 0.25597447f, + 0.69195485f, 0.8975601f, 0.36334568f, 0.29471546f, 0.04788208f, 0.24217117f, 0.062181532f, 0.38556474f, + 0.6020277f, 0.03156215f, 0.93655676f, 0.81369543f, 0.010527074f, 0.2611835f, 0.6630776f, 0.3972702f, + 0.44551176f, 0.27424216f, 0.9016098f, 0.22050089f, 0.9146384f, 0.53226113f, 0.6005109f, 0.8900659f, + 0.4176172f, 0.21532834f, 0.4191329f, 0.9055267f, 0.12900633f, 0.6134902f, 0.008604288f, 0.76215106f, + 0.68473387f, 0.5211961f, 0.71459657f, 0.50056237f, 0.7766764f, 0.10418975f, 0.42657375f, 0.7218073f, + 0.9979084f, 0.7546957f, 0.1364128f, 0.8845484f, 0.38850087f, 0.39324278f, 0.04554516f, 0.42129284f, + 0.8536634f, 0.5697224f, 0.20877302f, 0.65390605f, 0.3396778f, 0.956497f, 0.066022694f, 0.34206223f, + 0.017213225f, 0.3030849f, 0.6576238f, 0.9813073f, 0.58397317f, 0.99017924f, 0.59782606f, 0.788768f, + 0.9008311f, 0.91796166f, 0.22013813f, 0.959695f, 0.80288273f, 0.2662105f, 0.26139832f, 0.080626905f}; + const std::vector fc2_experts_weights = { + 0.6255686f, 0.09472537f, 0.71121234f, 0.65789884f, 0.065598905f, 0.63625044f, 0.45933473f, 0.7284089f, + 0.7868948f, 0.0029274821f, 0.95854944f, 0.919321f, 0.6989418f, 0.043019474f, 0.32138962f, 0.35509557f, + 0.37150103f, 0.78196156f, 0.6817853f, 0.89608955f, 0.31273842f, 0.6682699f, 0.6778976f, 0.08370459f, + 0.014990091f, 0.24055547f, 0.84227383f, 0.029270172f, 0.0647831f, 0.7801003f, 0.7697645f, 0.91119635f, + 0.12253064f, 0.13405013f, 0.75649333f, 0.9348151f, 0.7991694f, 0.57832605f, 0.66478735f, 0.97456336f, + 0.17739785f, 0.2729941f, 0.8497335f, 0.15788019f, 0.22429371f, 0.86499554f, 0.65776104f, 0.661535f, + 0.2880798f, 0.49309975f, 0.9576164f, 0.19988996f, 0.5039311f, 0.73779976f, 0.15482187f, 0.98558843f, + 0.25019473f, 0.379932f, 0.36471486f, 0.17417055f, 0.009367704f, 0.7819258f, 0.63283706f, 0.031699598f, + 0.1781866f, 0.994184f, 0.6911175f, 0.7006223f, 0.20085096f, 0.28080195f, 0.42452294f, 0.40856004f, + 0.15737581f, 0.5411925f, 0.549694f, 0.4366895f, 0.5693159f, 0.3018247f, 0.63012594f, 0.6885702f, + 0.2366305f, 0.004210472f, 0.7617172f, 0.61926836f, 0.24570602f, 0.981851f, 0.273876f, 0.8378734f, + 0.75366426f, 0.080795944f, 0.82247066f, 0.040263534f, 0.22299266f, 0.41664255f, 0.16297674f, 0.98845494f, + 0.39971018f, 0.69859487f, 0.053544044f, 0.7878332f, 0.34460813f, 0.11966437f, 0.5731115f, 0.7422309f, + 0.93269855f, 0.19460368f, 0.25394785f, 0.59613144f, 0.6356306f, 0.6922361f, 0.7744376f, 0.38662314f, + 0.7777848f, 0.8686458f, 0.36938924f, 0.8557286f, 0.74428976f, 0.9410264f, 0.21586305f, 0.2530955f, + 0.35543054f, 0.52536315f, 0.8000995f, 0.21456867f, 0.750327f, 0.3208093f, 0.80205464f, 0.47626138f, + 0.061956525f, 0.22487706f, 0.13812399f, 0.74798125f, 0.1647259f, 0.45834088f, 0.6078779f, 0.22580266f, + 0.644235f, 0.011788309f, 0.14224577f, 0.0469383f, 0.34876132f, 0.3178513f, 0.5715967f, 0.40754277f, + 0.735041f, 0.9583977f, 0.67939556f, 0.30301625f, 0.031807184f, 0.68110096f, 0.25227106f, 0.75443816f, + 0.83424246f, 0.69286025f, 0.9691554f, 0.9748982f, 0.60586995f, 0.13568163f, 0.94672066f, 0.26275212f, + 0.2638232f, 0.9183893f, 0.88740516f, 0.65107566f, 0.5313419f, 0.07941705f, 0.44809794f, 0.9795632f, + 0.6273294f, 0.542809f, 0.3961745f, 0.32560885f, 0.79801136f, 0.53083426f, 0.8252871f, 0.4115007f, + 0.7184546f, 0.70638496f, 0.57973206f, 0.8141865f, 0.81332296f, 0.96346164f, 0.88438797f, 0.37215167f, + 0.0766899f, 0.5914087f, 0.49563587f, 0.3695873f, 0.41627264f, 0.5235164f, 0.86481494f, 0.6558706f, + 0.32245284f, 0.29438752f, 0.37618434f, 0.3067485f, 0.9496114f, 0.76482266f, 0.95148784f, 0.5015968f, + 0.60083544f, 0.67338234f, 0.026723444f, 0.5446483f, 0.466555f, 0.21967298f, 0.112026334f, 0.9426372f, + 0.906533f, 0.73173434f, 0.97712487f, 0.29709607f, 0.41363865f, 0.6893093f, 0.4173867f, 0.4018826f, + 0.086719275f, 0.63433063f, 0.1978364f, 0.5181831f, 0.9874878f, 0.34609234f, 0.34240413f, 0.8016564f, + 0.31617337f, 0.4570613f, 0.96686924f, 0.29501313f, 0.14229488f, 0.22017813f, 0.36137718f, 0.26275063f, + 0.24053413f, 0.70197225f, 0.58496886f, 0.33996922f, 0.11154431f, 0.34257007f, 0.28898042f, 0.33729053f, + 0.048938513f, 0.60771453f, 0.13263822f, 0.11060041f, 0.091483414f, 0.70869184f, 0.19898665f, 0.29362458f, + 0.8919203f, 0.7654821f, 0.7866956f, 0.02524674f, 0.1414501f, 0.3112445f, 0.9130488f, 0.5511502f, + 0.12605143f, 0.5031309f, 0.11166459f, 0.39045036f, 0.36251247f, 0.9328308f, 0.65486836f, 0.41281444f, + 0.5844644f, 0.35566723f, 0.6964502f, 0.6977819f, 0.63427305f, 0.30511153f, 0.92657536f, 0.42781502f, + 0.30534166f, 0.813157f, 0.90752834f, 0.9975799f, 0.64812917f, 0.32955307f, 0.753946f, 0.92897725f, + 0.009582937f, 0.43805653f, 0.15901726f, 0.5931799f, 0.7067924f, 0.39670604f, 0.45817143f, 0.7250554f, + 0.41596514f, 0.08011025f, 0.900068f, 0.24834275f, 0.44507074f, 0.5471632f, 0.46995157f, 0.029657006f, + 0.7294f, 0.27288425f, 0.2406702f, 0.6194577f, 0.23906898f, 0.26892018f, 0.33152503f, 0.3121612f, + 0.29118127f, 0.36515707f, 0.6299379f, 0.095391035f, 0.19735986f, 0.5072957f, 0.56953406f, 0.77614623f, + 0.14877802f, 0.65959847f, 0.7841949f, 0.7776301f, 0.03428924f, 0.3091979f, 0.07021719f, 0.18359429f, + 0.77849144f, 0.42534047f, 0.7123557f, 0.20649683f, 0.57597995f, 0.19757104f, 0.749946f, 0.2813105f, + 0.37462044f, 0.06618434f, 0.50165176f, 0.9747401f, 0.7426891f, 0.23322952f, 0.50672436f, 0.44517577f, + 0.09746289f, 0.89204556f, 0.50806034f, 0.6052985f, 0.2980855f, 0.26604044f, 0.5824448f, 0.68485546f, + 0.612149f, 0.25902748f, 0.9854489f, 0.4263978f, 0.19379246f, 0.26614368f, 0.9922104f, 0.5000241f, + 0.4321279f, 0.2919191f, 0.3689273f, 0.078885734f, 0.10265827f, 0.79264474f, 0.9277247f, 0.9771502f, + 0.13902885f, 0.77043164f, 0.19051671f, 0.7982801f, 0.86077714f, 0.8869355f, 0.86002564f, 0.81278664f, + 0.5097318f, 0.7297412f, 0.32111454f, 0.7177174f, 0.33929902f, 0.49160433f, 0.064810574f, 0.3692627f, + 0.23706353f, 0.3313396f, 0.18070674f, 0.05027789f, 0.53255826f, 0.8244896f, 0.9553747f, 0.7917771f, + 0.24083132f, 0.005495131f, 0.6896569f, 0.78015697f, 0.07074398f, 0.67929304f, 0.9227386f, 0.5302883f, + 0.19877058f, 0.90993816f, 0.71350795f, 0.8311006f, 0.16185725f, 0.79097277f, 0.15846318f, 0.99474716f, + 0.28815013f, 0.80128354f, 0.6001208f, 0.63250524f, 0.4233225f, 0.7053677f, 0.29161406f, 0.028710365f, + 0.30789846f, 0.8917693f, 0.36836517f, 0.6571592f, 0.3151368f, 0.8750746f, 0.7992451f, 0.6765068f, + 0.24441916f, 0.091435075f, 0.5188247f, 0.20667112f, 0.9110969f, 0.019512117f, 0.72343415f, 0.998457f, + 0.7504142f, 0.6704894f, 0.01892668f, 0.9809466f, 0.41447622f, 0.032795787f, 0.9935814f, 0.29653466f, + 0.4646262f, 0.95763975f, 0.15339965f, 0.14625502f, 0.58130866f, 0.43307304f, 0.6151709f, 0.08064735f, + 0.5149533f, 0.27762014f, 0.25419557f, 0.04218155f, 0.7651092f, 0.59631824f, 0.077278376f, 0.89677596f, + 0.6508104f, 0.5927816f, 0.2064318f, 0.57540226f, 0.9817701f, 0.84294224f, 0.11056489f, 0.9564106f, + 0.5387549f, 0.74048257f, 0.88833815f, 0.9262546f, 0.11023259f, 0.93783194f, 0.16041255f, 0.53748304f, + 0.1506182f, 0.39038336f, 0.47727865f, 0.44018233f, 0.42101204f, 0.53943527f, 0.99320936f, 0.79050577f, + 0.77973497f, 0.7001237f, 0.88709056f, 0.4769255f, 0.5397561f, 0.60289854f, 0.06393474f, 0.09722155f, + 0.5613007f, 0.30437487f, 0.49082512f, 0.3852706f, 0.5778314f, 0.8253078f, 0.33417904f, 0.9004303f, + 0.8947809f, 0.11625093f, 0.11388689f, 0.09546256f, 0.22598988f, 0.30536187f, 0.46236527f, 0.3784039f, + 0.24737573f, 0.3411532f, 0.31912774f, 0.9905191f, 0.31468558f, 0.14199954f, 0.7078488f, 0.47111923f, + 0.882782f, 0.8124163f, 0.9593644f, 0.13382024f, 0.8214317f, 0.9196194f, 0.25308424f, 0.95958996f}; + const std::vector fc1_experts_bias = { + 0.8748215f, 0.5054756f, 0.74107623f, 0.32518923f, 0.0639081f, 0.62639004f, 0.64906263f, 0.17322052f, + 0.7424998f, 0.07288867f, 0.93031204f, 0.9841952f, 0.6361292f, 0.18628561f, 0.7433356f, 0.5852079f, + 0.6359594f, 0.66432667f, 0.88067776f, 0.28508204f, 0.38752747f, 0.63635296f, 0.55448055f, 0.9031888f, + 0.23738074f, 0.48179168f, 0.5934266f, 0.3672055f, 0.84085834f, 0.5546908f, 0.03788501f, 0.44583207f, + 0.27322155f, 0.5485856f, 0.44189203f, 0.00403291f, 0.40888733f, 0.45211035f, 0.35256076f, 0.9593902f, + 0.39090043f, 0.8212086f, 0.62385887f, 0.07793343f, 0.61749303f, 0.9143678f, 0.17294967f, 0.17681253f, + 0.9894245f, 0.901755f, 0.221053f, 0.8008725f, 0.43603396f, 0.007035315f, 0.5375667f, 0.661547f, + 0.35001957f, 0.67394173f, 0.072449565f, 0.84650797f, 0.92626715f, 0.77573335f, 0.58474565f, 0.66467446f}; + const std::vector fc2_experts_bias = { + 0.13822609f, 0.3750633f, 0.45226622f, 0.22175694f, 0.13068998f, 0.8363088f, 0.8393226f, 0.045905888f, + 0.65910596f, 0.7034011f, 0.97498417f, 0.78927684f, 0.95966834f, 0.33630514f, 0.8501932f, 0.9067007f, + 0.027835965f, 0.09864664f, 0.6012027f, 0.7730189f, 0.25159347f, 0.55506724f, 0.49927413f, 0.62655383f, + 0.23132521f, 0.7820195f, 0.8325047f, 0.15307087f, 0.5048437f, 0.5013873f, 0.66055787f, 0.96579224f}; + const std::vector output = { + 1.3775184f, 2.0985768f, 2.091839f, 2.9706357f, 1.9404914f, 1.9915576f, 2.3302228f, 2.3702593f, + 0.51896286f, 0.7936432f, 0.9944805f, 1.3225251f, 0.73894113f, 0.87975955f, 1.0468717f, 1.1585085f, + 0.012911659f, 0.045757107f, 0.27884653f, 0.3585817f, 0.116771236f, 0.25755364f, 0.23161705f, 0.2906256f, + 4.8571277f, 5.649453f, 5.485141f, 5.306299f, 4.767025f, 6.9010167f, 5.3520975f, 6.711155f}; + + RunMoETest(input, + router_probs, + fc1_experts_weights, + fc2_experts_weights, + fc1_experts_bias, + fc2_experts_bias, + output, + num_rows, + num_experts, + hidden_size, + inter_size, + "relu"); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 84d8a9c56df89..9ab78cac3aca4 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -589,5 +589,30 @@ TEST(FunctionTest, TestInlinedLocalFunctionNotRemoved) { #endif } +TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) { + // Verify this runs + constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/transform/gh_issue_18338.onnx"); + + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // Scalar shape for input_0 and output + const std::string input_names[] = {"input_0"}; + const std::string output_names[] = {"_val_3"}; + TensorShape input_shape; + MLFloat16 input_0_data{684.f}; + + OrtValue input_0; + Tensor::InitOrtValue(DataTypeImpl::GetType(), input_shape, &input_0_data, OrtMemoryInfo(), input_0); + + std::vector fetches(1); + RunOptions run_options; + ASSERT_STATUS_OK(session_object.Run(run_options, AsSpan(input_names), AsSpan({input_0}), + AsSpan(output_names), &fetches, 0)); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp index cf02d4f3628f9..87e3601612761 100644 --- a/onnxruntime/test/mlas/bench/bench_q4gemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4gemm.cpp @@ -33,7 +33,7 @@ void Q4GEMM(benchmark::State& state, MLAS_BLK_QUANT_TYPE qtype) { auto B1 = RandomVectorUniform(static_cast(N * K), -1.0f, 1.0f); std::vector C1(static_cast(M * N)); - std::vector B1_packed(pack_b_size); + std::vector B1_packed(pack_b_size); MlasQ4GemmPackB(qtype, B1_packed.data(), B1.data(), N, K, N); MLAS_Q4_GEMM_DATA_PARAMS params1; diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp index baa8f1a830ea1..e6e34bc88ad59 100644 --- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sgemm.cpp @@ -128,7 +128,7 @@ BENCHMARK_CAPTURE(SGEMM, PACKB_TransA, true, true, false)->Apply(GemmSizeProduct static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) { b->ArgNames(sgemm_bench_arg_names); - ArgsProduct(b, {{1, 1024, 2048}, {4096}, {4096}}); + ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); } BENCHMARK_CAPTURE(SGEMM, LLM, false, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime(); diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp new file mode 100644 index 0000000000000..2f2635dab0512 --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "mlas_q4.h" +#include "mlas_qnbit.h" + +#include + +#include "benchmark/benchmark.h" + +#include "bench_util.h" +#include "core/util/thread_utils.h" + +template +void SQNBITGEMM(benchmark::State& state) { + if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); + if (state.range(3) <= 0) throw std::invalid_argument("Threads must greater than 0!"); + + const size_t M = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); + const size_t threads = static_cast(state.range(3)); + + size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; + MlasBlockwiseQuantizedBufferSizes( + BlkBitWidth, BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), + QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); + + OrtThreadPoolParams tpo; + tpo.thread_pool_size = static_cast(threads); + tpo.auto_set_affinity = true; + + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + + auto A = RandomVectorUniform(static_cast(M * K), -1.0f, 1.0f); + auto B = RandomVectorUniform(static_cast(K * N), -1.0f, 1.0f); + std::vector C(static_cast(M * N)); + + std::vector QuantBData(QuantBDataSizeInBytes); + std::vector QuantBScale(QuantBScaleSize); + std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); + + MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), + Symmetric ? nullptr : QuantBZeroPoint.data(), + B.data(), BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), static_cast(N), + tp.get()); + + MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; + params.A = A.data(); + params.lda = K; + params.QuantBData = QuantBData.data(); + params.QuantBScale = QuantBScale.data(); + params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data(); + params.Bias = nullptr; + params.C = C.data(); + params.ldc = N; + + // warm up run + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, tp.get()); + + for (auto _ : state) { + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, tp.get()); + } +} + +static void GemmSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K", "Threads"}); + ArgsProduct(b, {{1, 1024, 2048}, {4096, 11008}, {4096, 11008}, {8}}); +} + +BENCHMARK(SQNBITGEMM<4, 16, false>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 16, true>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 32, false>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 32, true>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 64, false>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 64, true>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 128, false>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 128, true>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 256, false>)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4, 256, true>)->Apply(GemmSizeProducts)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_blockq4.cpp b/onnxruntime/test/mlas/unittest/test_blockq4.cpp index f836da8277bb8..07f0748fb7ed1 100644 --- a/onnxruntime/test/mlas/unittest/test_blockq4.cpp +++ b/onnxruntime/test/mlas/unittest/test_blockq4.cpp @@ -38,13 +38,17 @@ class MlasBlockwiseQdqTest : public MlasTestBase { int meta_rows; int meta_cols; - MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); + MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); int q_rows; int q_cols; - MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); + MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); - uint8_t* elements = InputElements.GetBuffer(q_rows * q_cols, true); + size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise, rows, columns, + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); int v = 7; for (int c = 0; c < columns; c++) { @@ -70,8 +74,8 @@ class MlasBlockwiseQdqTest : public MlasTestBase { } } - float* scales = InputScales.GetBuffer(meta_rows * meta_cols); - uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); + float* scales = InputScales.GetBuffer(q_scale_size); + uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(q_zp_size_in_bytes, true); if (zp) { for (int c = 0; c < meta_cols; c++) { for (int r = 0; r < meta_rows; r += 2) { diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h index 2861b0e746fdc..4db5c2bebca40 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.h +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h @@ -18,20 +18,6 @@ Module Name: #include "test_fp16.h" -inline bool -CloseEnough(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; -} - /** * @brief Test class for half precision GEMM * @tparam AType Data type of A matrix, can be either float or MLFp16 diff --git a/onnxruntime/test/mlas/unittest/test_main.cpp b/onnxruntime/test/mlas/unittest/test_main.cpp index 66b5a6a15db2b..505c0c01dfa90 100644 --- a/onnxruntime/test/mlas/unittest/test_main.cpp +++ b/onnxruntime/test/mlas/unittest/test_main.cpp @@ -1,17 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "test_util.h" - -#include #include +#include +#include + +#include "test_util.h" #if !defined(BUILD_MLAS_NO_ONNXRUNTIME) MLAS_THREADPOOL* GetMlasThreadPool(void) { - static MLAS_THREADPOOL* threadpool = new onnxruntime::concurrency::ThreadPool( + static auto threadpool = std::make_unique( &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 2, true); - return threadpool; + return threadpool.get(); } #else diff --git a/onnxruntime/test/mlas/unittest/test_q4gemm.h b/onnxruntime/test/mlas/unittest/test_q4gemm.h index 58a64491ae80b..97c6969b5bf91 100644 --- a/onnxruntime/test/mlas/unittest/test_q4gemm.h +++ b/onnxruntime/test/mlas/unittest/test_q4gemm.h @@ -19,20 +19,6 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" -inline bool -CloseEnough(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; -} - /** * @brief Test class for int4 block quantized GEMM * Note: only 2-D matmul supported for now diff --git a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp index a78a3261d1f2a..d3f601793a970 100644 --- a/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_q8q4gemm.cpp @@ -19,20 +19,6 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" -inline bool -CloseEnough(float actual, float expected) { - if (std::isnan(actual)) { - return std::isnan(expected); - } - float diff = std::abs(actual - expected); - float top = std::max(std::abs(actual), std::abs(expected)); - float ratio = 0; - if (top > 0.0001) { - ratio = diff / top; - } - return ratio < 0.005; -} - template static void blkq8_dequant_reference(const int8_t* src, float* dst, size_t M, size_t K) { const size_t num_blks = K / QBlkLen; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp new file mode 100644 index 0000000000000..6c97d60301573 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -0,0 +1,270 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sqnbitgemm.h + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM. + +--*/ + +#include "test_util.h" +#include "mlas_q4.h" +#include "mlas_qnbit.h" + +/** + * @brief Test class for n-bit int block quantized GEMM + * Note: only 2-D matmul supported for now + */ +template +class MlasSQNBitGemmTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferQuantBData; + MatrixGuardBuffer BufferQuantBZeroPoint; + MatrixGuardBuffer BufferQuantBScale; + MatrixGuardBuffer BufferDequantizedB; + MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; + + void CallGemm(size_t M, + size_t N, + size_t K, + const float* A, + size_t lda, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + const float* Bias, + float* C, + size_t ldc, + MLAS_THREADPOOL* Threadpool) { + MLAS_SQNBIT_GEMM_DATA_PARAMS params; + params.A = A; + params.lda = lda; + params.Bias = Bias; + params.C = C; + params.ldc = ldc; + params.QuantBData = QuantBData; + params.QuantBScale = QuantBScale; + params.QuantBZeroPoint = QuantBZeroPoint; + params.PostProcessor = nullptr; + + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ¶ms, Threadpool); + } + + void CallReferenceGemm(size_t M, + size_t N, + size_t K, + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + const float* Bias, + float* C) { + float* DequantizedBData = BufferDequantizedB.GetBuffer(K * N); + MlasDequantizeBlockwise( + DequantizedBData, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), GetMlasThreadPool()); + // Note: DequantizedBData is in column major layout. + + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const float* a = A + m * K; + const float* b = DequantizedBData + n * K; + float* c = C + (m * N) + n; + + float sum = Bias == nullptr ? 0.0f : Bias[n]; + for (size_t k = 0; k < K; k++) { + sum += (*a) * (*b); + b += 1; + a += 1; + } + *c = sum; + } + } + } + + public: + void Test(size_t M, size_t N, size_t K, + bool WithBias, bool Symmetric, bool WithThreadpool) { + MLAS_THREADPOOL* Threadpool = WithThreadpool ? GetMlasThreadPool() : nullptr; + + const float* A = BufferA.GetBuffer(K * M); + + const float* B = BufferB.GetBuffer(N * K); + + const float* Bias = nullptr; + if (WithBias) { + Bias = BufferBias.GetBuffer(N); + } + +#if 0 + auto print_matrix = [](size_t ncols, size_t nrows, const float* data) { + for (size_t row = 0; row < nrows; ++row) { + for (size_t col = 0; col < ncols; ++col) { + std::cout << data[row * nrows + col] << "\t"; + } + std::cout << "\n"; + } + }; + + std::cout << "A:\n"; + print_matrix(M, K, A); + std::cout << "B:\n"; + print_matrix(K, N, B); +#endif + + float* C = BufferC.GetBuffer(N * M, true); + float* CReference = BufferCReference.GetBuffer(N * M, true); + + // pack B + uint8_t* QuantBData = nullptr; + float* QuantBScale = nullptr; + uint8_t* QuantBZeroPoint = nullptr; + { + size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; + MlasBlockwiseQuantizedBufferSizes(BlkBitWidth, BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), + QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); + + QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes); + QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize); + if (Symmetric) { + QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes); + } + + MlasQuantizeBlockwise(QuantBData, QuantBScale, QuantBZeroPoint, + B, BlkLen, + /* columnwise */ true, + static_cast(K), static_cast(N), + static_cast(N), + GetMlasThreadPool()); + } + + CallGemm(M, N, K, A, /* lda */ K, QuantBData, QuantBScale, QuantBZeroPoint, Bias, C, /* ldc */ N, Threadpool); + CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); + + size_t f = 0; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + ASSERT_TRUE(CloseEnough(C[f], CReference[f])) + << "Expected: " << CReference[f] << " Actual: " << C[f] << "@[" << m << "x" << n << "], " + << "M=" << M << ", N=" << N << ", K=" << K; + } + } + } + + public: + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SQNBitGemm") + + "BlkBitWidth" + std::to_string(BlkBitWidth) + + "BlkLen" + std::to_string(BlkLen); + return suite_name.c_str(); + } +}; + +// +// Short Execute() test helper to register each test separately by all parameters. +// +template +class SQNBitGemmShortExecuteTest : public MlasTestFixture> { + public: + explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, + bool WithThreadpool, bool Symmetric, bool WithBias) + : M_(M), N_(N), K_(K), WithThreadpool_(WithThreadpool), Symmetric_(Symmetric), WithBias_(WithBias) { + } + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test( + M_, N_, K_, WithThreadpool_, Symmetric_, WithBias_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, + bool WithThreadpool, bool Symmetric, bool WithBias) { + std::stringstream ss; + ss << (WithThreadpool ? "SingleThread" : "Threaded") + << "/isSymmetric" << Symmetric + << "/M" << M << "xN" << N << "xK" << K + << "/hasBias" << WithBias; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSQNBitGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture>* { + return new SQNBitGemmShortExecuteTest( + M, N, K, WithThreadpool, Symmetric, WithBias); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t test_registered = 0; + + if (MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen)) { + for (bool WithThreadpool : {false, true}) { + for (bool Symmetric : {false, true}) { + for (size_t b = 1; b < 16; b++) { + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false); + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + } + for (size_t b = 16; b <= 256; b <<= 1) { + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, false); + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + } + for (size_t b = 256; b < 320; b += 32) { + test_registered += RegisterSingleTest(b, b, b, WithThreadpool, Symmetric, true); + } + for (size_t b = 1; b < 96; b++) { + test_registered += RegisterSingleTest(1, b, 32, WithThreadpool, Symmetric, false); + test_registered += RegisterSingleTest(1, 32, b, WithThreadpool, Symmetric, true); + test_registered += RegisterSingleTest(1, b, b, WithThreadpool, Symmetric, false); + } + test_registered += RegisterSingleTest(43, 500, 401, WithThreadpool, Symmetric, true); + + // test_registered += RegisterSingleTest(1001, 1027, 1031, WithThreadpool, Symmetric, false); + } + } + } + + return test_registered; + } + + private: + size_t M_, N_, K_; + bool WithThreadpool_, Symmetric_, WithBias_; +}; + +static size_t SQNBitGemmRegisterAllShortExecuteTests() { + size_t count = 0; + + count += SQNBitGemmShortExecuteTest<4, 16>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<4, 32>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<4, 64>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<4, 128>::RegisterShortExecuteTests(); + count += SQNBitGemmShortExecuteTest<4, 256>::RegisterShortExecuteTests(); + + return count; +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + if (is_short_execute) { + return SQNBitGemmRegisterAllShortExecuteTests() > 0; + } + return false; +}); diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h index db528ef7291cc..8eecda900ff27 100644 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ b/onnxruntime/test/mlas/unittest/test_util.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -253,3 +254,16 @@ inline void ReorderInputNchw(const int64_t* input_shape, const float* S, float* D += spatial_count * nchwc_channel_count; } } + +inline bool CloseEnough(float actual, float expected) { + if (std::isnan(actual)) { + return std::isnan(expected); + } + float diff = std::abs(actual - expected); + float top = std::max(std::abs(actual), std::abs(expected)); + float ratio = 0; + if (top > 0.0001) { + ratio = diff / top; + } + return ratio < 0.005; +} diff --git a/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc b/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc index bd2abadf49b81..d866045ba4962 100644 --- a/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc +++ b/onnxruntime/test/onnx/microbenchmark/reduceminmax.cc @@ -91,6 +91,8 @@ BENCHMARK(BM_FindMinMaxMlasSSE2) ->Arg(98304) ->Arg(160000); +#ifdef MLAS_TARGET_AMD64 + // MLAS avx implementation static void BM_FindMinMaxMlasAvx(benchmark::State& state) { const size_t batch_size = static_cast(state.range(0)); @@ -115,3 +117,5 @@ BENCHMARK(BM_FindMinMaxMlasAvx) ->Arg(80000) ->Arg(98304) ->Arg(160000); + +#endif // MLAS_TARGET_AMD64 diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc index a03d0da2538d4..9dcedd1fd7681 100644 --- a/onnxruntime/test/optimizer/compute_optimizer_test.cc +++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc @@ -847,7 +847,7 @@ Test graph includes multiple equivalent subgraphs as below. Add an Identity node because currently, we don't allow Gather generates graph output. */ TEST(ComputeOptimizerTests, GatherLayerNormalization) { - std::vector> test_config_pairs{ + std::vector> test_config_pairs{ // { // is_scalar_slice, // ln_axis_before_propagation, @@ -929,13 +929,6 @@ TEST(ComputeOptimizerTests, GatherLayerNormalization) { const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); TEST_RETURN_IF_NOT(slice_out_shape != nullptr); - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - - auto& axis_attr = attrs.at("axis"); - auto axis_value = (int)axis_attr.i(); - TEST_RETURN_IF_NOT(axis_value == ln_axis_after); - if (is_scalar_slice) { TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 2); TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && @@ -951,10 +944,15 @@ TEST(ComputeOptimizerTests, GatherLayerNormalization) { TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && slice_out_shape->dim(2).dim_value() == 256); } - } else { TEST_RETURN_IF_NOT(producer_node == nullptr); } + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == ln_axis_after); } } @@ -2841,165 +2839,110 @@ Test graph include multiple equivalent subgraphs as below. Add an Identity node because currently we don't allow Reshape generate graph output. */ -TEST(ComputeOptimizerTests, ReshapeLayerNormalization_PropagationOnOneBranch) { - const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); - auto pre_graph_checker = [](Graph& graph) -> Status { - auto op_count_pre = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); - TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); - return Status::OK(); - }; - - auto post_graph_checker = [](Graph& graph) { - auto op_count_post = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_post.size() == 3U); - TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); - - for (Node& node : graph.Nodes()) { - if (node.OpType() == "LayerNormalization") { - const auto& input_defs = node.InputDefs(); - - { - auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); - TEST_RETURN_IF_NOT(producer_node != nullptr); - TEST_RETURN_IF_NOT(producer_node->OpType() == "Reshape"); - - InlinedVector values; - constexpr bool require_constant = true; - NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); - TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, require_constant)); - TEST_RETURN_IF_NOT(values.size() == 2); - TEST_RETURN_IF_NOT(values[0] == -1); - TEST_RETURN_IF_NOT(values[1] == 1024); - } - - { - auto producer_node = graph.GetProducerNode(input_defs[1]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } - - { - auto producer_node = graph.GetProducerNode(input_defs[2]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } - } - } - return Status::OK(); +TEST(ComputeOptimizerTests, ReshapeLayerNormalization) { + std::vector> test_config_pairs{ + // { + // ln_axis_before_propagation, + // expected_ln_axis_after_propagation, + // expected to propagate + // } + {0, 0, false}, + {1, 1, false}, + {2, 1, true}, + {-3, -3, false}, + {-2, -2, false}, + {-1, -1, true}, }; - std::vector fist_dim_values = {-1, 128}; - for (auto first_dim_value : fist_dim_values) { - auto build_test_case = [&first_dim_value](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput({{4, 32, 1024}}); - auto* input2_arg = builder.MakeInput({{1024}}); - auto* input3_arg = builder.MakeInput({{1024}}); - auto* ln_out = builder.MakeIntermediate(); - builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) - .AddAttribute("axis", static_cast(-1)); - - auto* shape_initializer = builder.MakeInitializer({2}, {first_dim_value, 1024}); - auto* reshape_out = builder.MakeIntermediate(); - builder.AddNode("Reshape", {ln_out, shape_initializer}, {reshape_out}); + for (auto p : test_config_pairs) { + int64_t ln_axis_before = std::get<0>(p); + int64_t ln_axis_after = std::get<1>(p); + bool expected_to_propagate = std::get<2>(p); - auto* identity_out = builder.MakeOutput(); - builder.AddNode("Identity", {reshape_out}, {identity_out}); + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto pre_graph_checker = [](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + return Status::OK(); }; - const std::vector opsets{12, 13, 14}; - for (auto& opset_version : opsets) { - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer), - TransformerLevel::Level1, - 1, pre_graph_checker, post_graph_checker)); - } - } -} + auto post_graph_checker = [ln_axis_after, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); -/* -Test graph include multiple equivalent subgraphs as below. - graph input [4, 32, 1024] (float) graph input [1024] (float) graph input [1024] (float) - | | / - \_____________ _______/ __________________________/ - \ / / - LayerNormalization - | - Reshape - | - Identity - | - graph out [128, 1024] (float) + for (Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + const auto& input_defs = node.InputDefs(); -Add an Identity node because currently we don't allow Reshape generate graph output. -*/ -TEST(ComputeOptimizerTests, ReshapeLayerNormalization_NoPropagation) { - const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); - auto pre_graph_checker = [](Graph& graph) -> Status { - auto op_count_pre = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); - TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); - return Status::OK(); - }; + if (expected_to_propagate) { + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Reshape"); - auto post_graph_checker = [](Graph& graph) { - auto op_count_post = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_post.size() == 3U); - TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, require_constant)); + TEST_RETURN_IF_NOT(values.size() == 2); + TEST_RETURN_IF_NOT(values[0] == -1); + TEST_RETURN_IF_NOT(values[1] == 1024); + } else { + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + TEST_RETURN_IF_NOT(producer_node == nullptr); + } - for (Node& node : graph.Nodes()) { - if (node.OpType() == "LayerNormalization") { - const auto& input_defs = node.InputDefs(); + { + auto producer_node = graph.GetProducerNode(input_defs[1]->Name()); + TEST_RETURN_IF_NOT(producer_node == nullptr); + } - { - auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } + { + auto producer_node = graph.GetProducerNode(input_defs[2]->Name()); + TEST_RETURN_IF_NOT(producer_node == nullptr); + } - { - auto producer_node = graph.GetProducerNode(input_defs[1]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - { - auto producer_node = graph.GetProducerNode(input_defs[2]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == ln_axis_after); } } - } - return Status::OK(); - }; - - std::vector fist_dim_values = {-1, 128}; - for (auto first_dim_value : fist_dim_values) { - auto build_test_case = [&first_dim_value](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput({{4, 32, 1024}}); - auto* input2_arg = builder.MakeInput({{1024}}); - auto* input3_arg = builder.MakeInput({{1024}}); - auto* ln_out = builder.MakeIntermediate(); - builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) - .AddAttribute("axis", static_cast(1)); - - auto* shape_initializer = builder.MakeInitializer({2}, {first_dim_value, 1024}); - auto* reshape_out = builder.MakeIntermediate(); - builder.AddNode("Reshape", {ln_out, shape_initializer}, {reshape_out}); - - auto* identity_out = builder.MakeOutput(); - builder.AddNode("Identity", {reshape_out}, {identity_out}); + return Status::OK(); }; - const std::vector opsets{12, 13, 14}; - for (auto& opset_version : opsets) { - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer), - TransformerLevel::Level1, - 1, pre_graph_checker, post_graph_checker)); + std::vector fist_dim_values = {-1, 128}; + for (auto first_dim_value : fist_dim_values) { + auto build_test_case = [ln_axis_before, &first_dim_value](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{4, 32, 1024}}); + auto* input2_arg = builder.MakeInput({{1024}}); + auto* input3_arg = builder.MakeInput({{1024}}); + auto* ln_out = builder.MakeIntermediate(); + builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) + .AddAttribute("axis", ln_axis_before); + + auto* shape_initializer = builder.MakeInitializer({2}, {first_dim_value, 1024}); + auto* reshape_out = builder.MakeIntermediate(); + builder.AddNode("Reshape", {ln_out, shape_initializer}, {reshape_out}); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {reshape_out}, {identity_out}); + }; + + const std::vector opsets{12, 13, 14}; + for (auto& opset_version : opsets) { + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } } } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index b82f3345dfcd1..17b26ed7ca4ca 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -6910,7 +6910,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; // OpSet-12 { @@ -6933,8 +6936,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14 @@ -6962,8 +6965,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-18 @@ -6991,8 +6994,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } @@ -7023,7 +7026,10 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; // OpSet-12 { @@ -7042,8 +7048,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14 @@ -7063,8 +7069,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-18 @@ -7084,13 +7090,180 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + +TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{2, 3, 3, 3}}); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_out_3 = builder.MakeIntermediate(); + auto* transpose_out_1 = builder.MakeOutput(); + auto* transpose_out_2 = builder.MakeOutput(); + auto* transpose_out_3 = builder.MakeOutput(); + + builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(-2)); + builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; + + // OpSet-12 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axes").ints().at(0))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } + + // OpSet-14 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } + + // OpSet-18 + { + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } +TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInitializer({2, 3, 3, 3}, std::vector(54)); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* gather_out_2 = builder.MakeIntermediate(); + auto* gather_out_3 = builder.MakeIntermediate(); + auto* transpose_out_1 = builder.MakeOutput(); + auto* transpose_out_2 = builder.MakeOutput(); + auto* transpose_out_3 = builder.MakeOutput(); + + builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(-2)); + builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { - auto pre_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); return Status::OK(); }; + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + return Status::OK(); + }; auto post_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); @@ -7130,8 +7303,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // Invalid Gather indices. @@ -7166,8 +7339,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // Invalid Gather axis. @@ -7202,8 +7375,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } @@ -7250,8 +7423,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } // OpSet-14, Tind is int64. @@ -7289,8 +7462,8 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { }; std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } } diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 3073dde9d8e4c..3da3dc858175b 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -142,11 +142,6 @@ TEST_F(QnnHTPBackendTests, MatMulOp_HTP_u8) { } // Test QDQ MatMul with 16-bit act, 8-bit weights (static) -// TODO: (SLIGHT) Inaccuracy detected for output 'output', element 0. -// Output quant params: scale=0.0015259021893143654, zero_point=0. -// Expected val: 98 -// QNN QDQ val: 97.720298767089844 (err 0.27970123291015625) -// CPU QDQ val: 97.726402282714844 (err 0.27359771728515625) TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; @@ -158,6 +153,40 @@ TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { 7e-3f); } +// Test QDQ MatMul with uint16 activation uint16 weights, both dynamic +// Inaccuracy detected for output 'output_0', element 1. +// Output quant params: scale=0.0015259021893143654, zero_point=0. +// Expected val: 40 +// QNN QDQ val: 39.681087493896484 (err 0.31891250610351562) +// CPU QDQ val: 39.99847412109375 (err 0.00152587890625) +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16Dynamic) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), + TestInputDef({3, 2}, false, input1_data), + ExpectedEPNodeAssignment::All, + 18, + true, // Use com.microsoft Q/DQ ops + 7e-3f); +} + +// Test QDQ MatMul with uint16 activation uint16 weights, both dynamic +// Inaccuracy detected for output 'output_0', element 1. +// Output quant params: scale=0.71908456087112427, zero_point=1. +// Expected val: 46848.41015625 +// QNN QDQ val: 46844.04296875 (err 4.3671875) +// CPU QDQ val: 46848.359375 (err 0.05078125) +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16DynamicLarge) { + std::vector input0_data = GetFloatDataInRange(-10.0f, 10.0f, 12 * 96 * 512); + std::vector input1_data = GetFloatDataInRange(-10.0f, 10.0f, 12 * 96 * 512); + RunQDQMatMulOpOpTest(TestInputDef({1, 12, 96, 512}, false, input0_data), + TestInputDef({1, 12, 512, 96}, false, input1_data), + ExpectedEPNodeAssignment::All, + 18, + true, // Use com.microsoft Q/DQ ops + 7e-3f); +} + // Test 16-bit QDQ MatMul with static weights // TODO: Inaccuracy detected for output 'output', element 0. // Output quant params: scale=0.0015259021893143654, zero_point=0. diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py deleted file mode 100644 index 4cf2e5d7f7588..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ /dev/null @@ -1,1026 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import copy -import os -import unittest - -import numpy as np -import onnx -import torch -import torch.nn as nn -import torch.nn.functional as F -from helper import get_name -from numpy.testing import assert_allclose -from torchvision import datasets, transforms - -import onnxruntime -from onnxruntime.capi.ort_trainer import ( - IODescription, - LossScaler, - ModelDescription, - ORTTrainer, - generate_sample, - load_checkpoint, - save_checkpoint, -) - -SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) - - -def ort_trainer_learning_rate_description(): - return IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - - -def remove_extra_info(model_desc): - simple_model_desc = copy.deepcopy(model_desc) - for input_desc in simple_model_desc.inputs_: - input_desc.dtype_ = None - input_desc.num_classes_ = None - for output_desc in simple_model_desc.outputs_: - output_desc.dtype_ = None - output_desc.num_classes_ = None - return simple_model_desc - - -def bert_model_description(): - vocab_size = 30528 - input_ids_desc = IODescription( - "input_ids", - ["batch", "max_seq_len_in_batch"], - torch.int64, - num_classes=vocab_size, - ) - segment_ids_desc = IODescription("segment_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2) - input_mask_desc = IODescription("input_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2) - masked_lm_labels_desc = IODescription( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - torch.int64, - num_classes=vocab_size, - ) - next_sentence_labels_desc = IODescription( - "next_sentence_labels", - [ - "batch", - ], - torch.int64, - num_classes=2, - ) - loss_desc = IODescription("loss", [], torch.float32) - - return ModelDescription( - [ - input_ids_desc, - segment_ids_desc, - input_mask_desc, - masked_lm_labels_desc, - next_sentence_labels_desc, - ], - [loss_desc], - ) - - -def map_optimizer_attributes(name): - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) - if no_decay: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - else: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6} - - -def generate_sample_batch(desc, batch_size, device): - desc_ = copy.deepcopy(desc) - desc_.shape_[0] = batch_size - sample = generate_sample(desc_, device) - return sample - - -def create_ort_trainer( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc=True, - loss_scaler=None, - deepspeed_zero_stage=0, -): - model_desc = bert_model_description() - simple_model_desc = remove_extra_info(model_desc) if use_simple_model_desc else model_desc - learning_rate_description = ort_trainer_learning_rate_description() - device = torch.device("cuda", 0) - - onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx")) - - model = ORTTrainer( - onnx_model, - None, - simple_model_desc, - "LambOptimizer", - map_optimizer_attributes, - learning_rate_description, - device, - gradient_accumulation_steps=gradient_accumulation_steps, - world_rank=0, - world_size=1, - loss_scaler=loss_scaler, - use_mixed_precision=use_mixed_precision, - allreduce_post_accumulation=allreduce_post_accumulation, - deepspeed_zero_stage=deepspeed_zero_stage, - ) - - return model, model_desc, device - - -def run_bert_training_test( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc=True, - use_internel_loss_scale=False, -): - torch.manual_seed(1) - onnxruntime.set_seed(1) - - loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None - - model, model_desc, device = create_ort_trainer( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc, - loss_scaler, - ) - - if loss_scaler is None: - loss_scaler = LossScaler(model.loss_scale_input_name, True) - - input_ids_batches = [] - segment_ids_batches = [] - input_mask_batches = [] - masked_lm_labels_batches = [] - next_sentence_labels_batches = [] - batch_size = 16 - num_batches = 8 - for _batch in range(num_batches): - input_ids_batches = [ - *input_ids_batches, - generate_sample_batch(model_desc.inputs_[0], batch_size, device), - ] - segment_ids_batches = [ - *segment_ids_batches, - generate_sample_batch(model_desc.inputs_[1], batch_size, device), - ] - input_mask_batches = [ - *input_mask_batches, - generate_sample_batch(model_desc.inputs_[2], batch_size, device), - ] - masked_lm_labels_batches = [ - *masked_lm_labels_batches, - generate_sample_batch(model_desc.inputs_[3], batch_size, device), - ] - next_sentence_labels_batches = [ - *next_sentence_labels_batches, - generate_sample_batch(model_desc.inputs_[4], batch_size, device), - ] - - lr_batch_list = [ - 0.0000000e00, - 4.6012269e-07, - 9.2024538e-07, - 1.3803681e-06, - 1.8404908e-06, - 2.3006135e-06, - 2.7607362e-06, - 3.2208588e-06, - 3.6809815e-06, - ] - - actual_losses = [] - actual_all_finites = [] - - for batch_count in range(num_batches): - input_ids = generate_sample_batch(model_desc.inputs_[0], batch_size, device) - segment_ids = generate_sample_batch(model_desc.inputs_[1], batch_size, device) - input_mask = generate_sample_batch(model_desc.inputs_[2], batch_size, device) - masked_lm_labels = generate_sample_batch(model_desc.inputs_[3], batch_size, device) - next_sentence_labels = generate_sample_batch(model_desc.inputs_[4], batch_size, device) - lr = lr_batch_list[batch_count] - - learning_rate = torch.tensor([lr]).to(device) - training_args = [ - input_ids, - segment_ids, - input_mask, - masked_lm_labels, - next_sentence_labels, - learning_rate, - ] - if use_mixed_precision: - if not use_internel_loss_scale: - loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device) - training_args.append(loss_scale) - actual_loss = model.train_step(*training_args) - if isinstance(actual_loss, (list, tuple)): - assert len(actual_loss) == 2 - actual_loss, actual_all_finite = actual_loss - if not use_internel_loss_scale: - loss_scaler.update_loss_scale(actual_all_finite.item()) - actual_all_finites = [ - *actual_all_finites, - actual_all_finite.cpu().numpy().item(0), - ] - - actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)] - else: - loss = model(*training_args) - actual_losses = [*actual_losses, loss.cpu().numpy().item(0)] - - if batch_count == num_batches - 1: - # test eval_step api with fetches at the end of the training. - # if eval_step is called during the training, it will affect the actual training loss (training session is stateful). - eval_loss = model.eval_step( - input_ids, - segment_ids, - input_mask, - masked_lm_labels, - next_sentence_labels, - fetches=["loss"], - ) - eval_loss = eval_loss.cpu().numpy().item(0) - - # If using internal loss scale, all_finites are handled internally too. - if use_mixed_precision and not use_internel_loss_scale: - return actual_losses, actual_all_finites, eval_loss - else: - return actual_losses, eval_loss - - -class MNISTWrapper: - class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - self.register_buffer("bias_buffer", torch.tensor(1e-6)) - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - out = torch.add(out, self.bias_buffer.to(out.dtype)) - return out - - class NeuralNetWithLoss(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x, target): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return F.nll_loss(F.log_softmax(out, dim=1), target), out - - def my_loss(x, target): # noqa: N805 - return F.nll_loss(F.log_softmax(x, dim=1), target) - - def train_with_trainer(self, learningRate, trainer, device, train_loader, epoch): - actual_losses = [] - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - args_log_interval = 100 - if batch_idx % args_log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - actual_losses = [*actual_losses, loss.cpu().numpy().item()] - - return actual_losses - - # TODO: comple this once ORT training can do evaluation. - def test_with_trainer(self, trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = F.log_softmax(trainer.eval_step((data), fetches=["probability"]), dim=1) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, - correct, - len(test_loader.dataset), - 100.0 * correct / len(test_loader.dataset), - ) - ) - - return test_loss, correct / len(test_loader.dataset) - - def mnist_model_description(): - input_desc = IODescription("input1", ["batch", 784], torch.float32) - label_desc = IODescription( - "label", - [ - "batch", - ], - torch.int64, - num_classes=10, - ) - loss_desc = IODescription("loss", [], torch.float32) - probability_desc = IODescription("probability", ["batch", 10], torch.float32) - return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc]) - - def get_loaders(self): - args_batch_size = 64 - args_test_batch_size = 1000 - - kwargs = {"num_workers": 0, "pin_memory": True} - # set shuffle to False to get deterministic data set among different torch version - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - os.path.join(SCRIPT_DIR, "data"), - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args_batch_size, - shuffle=False, - **kwargs, - ) - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - os.path.join(SCRIPT_DIR, "data"), - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args_test_batch_size, - shuffle=False, - **kwargs, - ) - - return train_loader, test_loader - - def get_model(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - - # warning: changes the pytorch random generator state - model = MNISTWrapper.NeuralNet(input_size, hidden_size, num_classes) - model_desc = MNISTWrapper.mnist_model_description() - return model, model_desc - - def get_model_with_internal_loss(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - - # warning: changes the pytorch random generator state - model = MNISTWrapper.NeuralNetWithLoss(input_size, hidden_size, num_classes) - model_desc = MNISTWrapper.mnist_model_description() - return model, model_desc - - def get_trainer( - self, - model, - model_desc, - device, - onnx_opset_ver=12, - frozen_weights=[], # noqa: B006 - internal_loss_fn=False, - get_lr_this_step=None, - optimizer="SGDOptimizer", - ): - loss_fn = MNISTWrapper.my_loss if not internal_loss_fn else None - return ORTTrainer( - model, - loss_fn, - model_desc, - optimizer, - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - _opset_version=onnx_opset_ver, - frozen_weights=frozen_weights, - get_lr_this_step=get_lr_this_step, - ) - - -class TestOrtTrainer(unittest.TestCase): - def run_mnist_training_and_testing(onnx_opset_ver): # noqa: N805 - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - trainer = mnist.get_trainer(model, model_desc, device, onnx_opset_ver=onnx_opset_ver) - - learningRate = 0.01 # noqa: N806 - args_epochs = 2 - expected_losses = [ - 2.312044143676758, - 0.8018650412559509, - 0.5819257497787476, - 0.47025489807128906, - 0.35800155997276306, - 0.41124576330184937, - 0.2731882333755493, - 0.4201386570930481, - 0.39458805322647095, - 0.38380366563796997, - 0.2722422480583191, - 0.24230478703975677, - 0.23505745828151703, - 0.33442264795303345, - 0.21140924096107483, - 0.31545233726501465, - 0.18556523323059082, - 0.3453553020954132, - 0.29598352313041687, - 0.3595045208930969, - ] - - expected_test_losses = [0.3145490005493164, 0.256188737487793] - expected_test_accuracies = [0.9075, 0.9265] - - actual_losses = [] - actual_test_losses, actual_accuracies = [], [] - for epoch in range(1, args_epochs + 1): - actual_losses = [ - *actual_losses, - *mnist.train_with_trainer(learningRate, trainer, device, train_loader, epoch), - ] - - test_loss, accuracy = mnist.test_with_trainer(trainer, device, test_loader) - actual_test_losses = [*actual_test_losses, test_loss] - actual_accuracies = [*actual_accuracies, accuracy] - - # if you update outcomes, also do so for resume from checkpoint test - # args_checkpoint_epoch = 1 - # if epoch == args_checkpoint_epoch: - # state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()} - # torch.save(state, get_name("ckpt_mnist.pt")) - - print("actual_losses=", actual_losses) - print("actual_test_losses=", actual_test_losses) - print("actual_accuracies=", actual_accuracies) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # import pdb; pdb.set_trace() - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_test_losses, - actual_test_losses, - rtol=rtol, - err_msg="test loss mismatch", - ) - assert_allclose( - expected_test_accuracies, - actual_accuracies, - rtol=rtol, - err_msg="test accuracy mismatch", - ) - - def test_mnist_training_and_testing_opset12(self): - TestOrtTrainer.run_mnist_training_and_testing(onnx_opset_ver=12) - - def test_mnist_resume_training_and_testing(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - learningRate = 0.01 # noqa: N806 - args_epochs = 2 - args_checkpoint_epoch = 1 - # should match those in test without checkpointing - expected_losses = [ - 0.26509523391723633, - 0.24135658144950867, - 0.2397943139076233, - 0.3351520597934723, - 0.20998981595039368, - 0.31488314270973206, - 0.18481917679309845, - 0.34727591276168823, - 0.2971782684326172, - 0.3609251379966736, - ] - - expected_test_losses = [0.25632242965698243] - expected_test_accuracies = [0.9264] - - actual_losses = [] - actual_test_losses, actual_accuracies = [], [] - - # restore from checkpoint - resume_trainer = mnist.get_trainer(model, model_desc, device) - checkpoint = torch.load(get_name("ckpt_mnist.pt"), map_location="cpu") - torch.set_rng_state(checkpoint["rng_state"]) - resume_trainer.load_state_dict(checkpoint["model"], strict=True) - - # continue .. - for epoch in range(args_checkpoint_epoch + 1, args_epochs + 1): - actual_losses = [ - *actual_losses, - *mnist.train_with_trainer(learningRate, resume_trainer, device, train_loader, epoch), - ] - - test_loss, accuracy = mnist.test_with_trainer(resume_trainer, device, test_loader) - actual_test_losses = [*actual_test_losses, test_loss] - actual_accuracies = [*actual_accuracies, accuracy] - - print("actual_losses=", actual_losses) - print("actual_test_losses=", actual_test_losses) - print("actual_accuracies=", actual_accuracies) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # import pdb; pdb.set_trace() - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_test_losses, - actual_test_losses, - rtol=rtol, - err_msg="test loss mismatch", - ) - assert_allclose( - expected_test_accuracies, - actual_accuracies, - rtol=rtol, - err_msg="test accuracy mismatch", - ) - - def test_mnist_state_dict(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - state_dict = trainer.state_dict() - assert state_dict == {} - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - state_dict = trainer.state_dict() - assert state_dict.keys() == { - "fc1.bias", - "fc1.weight", - "fc2.bias", - "fc2.weight", - "bias_buffer", - } - - def test_mnist_save_as_onnx(self): - torch.manual_seed(1) - device = torch.device("cuda") - onnx_file_name = "mnist.onnx" - if os.path.exists(onnx_file_name): - os.remove(onnx_file_name) - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - trainer.save_as_onnx(onnx_file_name) - assert not os.path.exists(onnx_file_name) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - trainer.save_as_onnx(onnx_file_name) - assert os.path.exists(onnx_file_name) - - def test_mnist_device(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - for model_device in [torch.device("cpu"), torch.device("cuda")]: - model.to(model_device) - trainer = mnist.get_trainer(model, model_desc, device) - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - def test_mnist_initializer_names(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - assert ({n.name for n in trainer.onnx_model_.graph.initializer} - {"bias_buffer"}) == { - n for n, t in model.named_parameters() - } - - def test_mnist_initializer_names_with_internal_loss(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model_with_internal_loss() - - def get_lr_this_step(global_step): - learningRate = 0.02 # noqa: N806 - return torch.tensor([learningRate]) - - trainer = mnist.get_trainer( - model, - model_desc, - device, - internal_loss_fn=True, - get_lr_this_step=get_lr_this_step, - ) - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target) - - assert {n.name for n in trainer.onnx_model_.graph.initializer} == {n for n, t in model.named_parameters()} - - def test_mnist_frozen_weight(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=["fc1.weight"]) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_1 = trainer.state_dict()["fc1.weight"] - fc2_trainstep_1 = trainer.state_dict()["fc2.weight"] - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_2 = trainer.state_dict()["fc1.weight"] - fc2_trainstep_2 = trainer.state_dict()["fc2.weight"] - assert np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and not np.array_equal(fc2_trainstep_1, fc2_trainstep_2) - - def test_mnist_torch_buffer(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_1 = trainer.state_dict()["fc1.weight"] - bias_buffer_trainstep_1 = trainer.state_dict()["bias_buffer"] - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_2 = trainer.state_dict()["fc1.weight"] - bias_buffer_trainstep_2 = trainer.state_dict()["bias_buffer"] - assert not np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and np.array_equal( - bias_buffer_trainstep_1, bias_buffer_trainstep_2 - ) - - def test_mnist_frozen_weight_checkpoint(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=["fc1.weight"]) - - learningRate = 0.02 # noqa: N806 - - # do one train step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - # do one eval step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.eval_step(data, target) - - # save checkpoint, load model and compare - state_dict = trainer.state_dict() - - new_model, _ = mnist.get_model() - trainer = mnist.get_trainer(new_model, model_desc, device, frozen_weights=["fc1.weight"]) - trainer.load_state_dict(state_dict) - - ckpt_loss, _ = trainer.eval_step(data, target) - assert loss == ckpt_loss - - loaded_state_dict = trainer.state_dict() - assert state_dict.keys() == loaded_state_dict.keys() - - def test_mnist_training_checkpoint(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer( - model, - model_desc, - device, - optimizer="LambOptimizer", - frozen_weights=["fc1.weight"], - ) - - learningRate = 0.02 # noqa: N806 - - # do 5 train step - for _i in range(5): - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - # do one eval step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.eval_step(data, target) - - # save checkpoint, load model and compare - state_dict = trainer.state_dict() - - new_model, _ = mnist.get_model() - trainer = mnist.get_trainer( - new_model, - model_desc, - device, - optimizer="LambOptimizer", - frozen_weights=["fc1.weight"], - ) - trainer.load_state_dict(state_dict) - - ckpt_loss, _ = trainer.eval_step(data, target) - assert loss == ckpt_loss - - loaded_state_dict = trainer.state_dict() - assert state_dict.keys() == loaded_state_dict.keys() - for key in state_dict: - assert np.array_equal(state_dict[key], loaded_state_dict[key]) - - def test_bert_training_basic(self): - expected_losses = [ - 11.027887, - 11.108191, - 11.055356, - 11.040912, - 10.960277, - 11.02691, - 11.082471, - 10.920979, - ] - expected_eval_loss = [10.958977] - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=False, - ) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # print('losses expected: ', expected_losses) - # print('losses actual: ', actual_losses) - # print('eval_loss expected: ', expected_eval_loss) - # print('eval_loss actual: ', actual_eval_loss) - # import pdb; pdb.set_trace() - - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_gradient_accumulation(self): - expected_losses = [ - 11.027887, - 11.108191, - 11.055354, - 11.040904, - 10.960266, - 11.026897, - 11.082475, - 10.920998, - ] - expected_eval_loss = [10.958998] - - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=4, - use_mixed_precision=False, - allreduce_post_accumulation=False, - ) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # print('losses expected: ', expected_losses) - # print('losses actual: ', actual_losses) - # print('eval_loss expected: ', expected_eval_loss) - # print('eval_loss actual: ', actual_eval_loss) - # import pdb; pdb.set_trace() - - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_checkpointing_basic(self): - model, _, _ = create_ort_trainer( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=True, - use_simple_model_desc=True, - loss_scaler=None, - ) - sd = model.state_dict() - - # modify one of the default values - sd["bert.encoder.layer.0.attention.output.LayerNorm.weight"] += 1 - model.load_state_dict(sd) - - ckpt_dir = "testdata" - save_checkpoint(model, ckpt_dir, "bert_toy_save_test") - del model - - # create new model - model2, _, _ = create_ort_trainer( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=True, - use_simple_model_desc=True, - loss_scaler=None, - ) - - # load changed checkpoint - load_checkpoint(model2, ckpt_dir, "bert_toy_save_test") - loaded_sd = model2.state_dict() - - for k, v in loaded_sd.items(): - assert torch.all(torch.eq(v, sd[k])) - - def test_wrap_model_loss_fn_state_dict(self): - torch.manual_seed(1) - device = torch.device("cuda") - - class LinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - - def forward(self, y=None, x=None): - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - pt_model = LinearModel() - data = torch.randn(2, 2) - label = torch.tensor([0, 1], dtype=torch.int64) - input_desc = IODescription("x", [2, 2], torch.float32) - label_desc = IODescription( - "label", - [ - 2, - ], - torch.int64, - num_classes=4, - ) - output_desc = IODescription("output", [2, 4], torch.float32) - loss_desc = IODescription("loss", [], torch.float32) - model_desc = ModelDescription([input_desc, label_desc], [loss_desc, output_desc]) - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - def get_lr_this_step(global_step): - learningRate = 0.02 # noqa: N806 - return torch.tensor([learningRate]) - - ort_trainer = ORTTrainer( - pt_model, - loss_fn, - model_desc, - "SGDOptimizer", - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - get_lr_this_step=get_lr_this_step, - ) - ort_trainer.train_step(x=data, label=label) - state_dict = ort_trainer.state_dict() - assert state_dict.keys() == {"linear.bias", "linear.weight"} - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py deleted file mode 100644 index 3b994e6f26710..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import unittest - -from numpy.testing import assert_allclose, assert_array_equal -from onnxruntime_test_ort_trainer import run_bert_training_test - - -class TestOrtTrainer(unittest.TestCase): - def test_bert_training_mixed_precision(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006105422973633, - 11.047048568725586, - 11.027417182922363, - 11.015759468078613, - 11.060905456542969, - 10.971782684326172, - ] - expected_all_finites = [True, True, True, True, True, True, True, True] - expected_eval_loss = [10.959012985229492] - actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_mixed_precision_internal_loss_scale(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006105422973633, - 11.047048568725586, - 11.027417182922363, - 11.015759468078613, - 11.060905456542969, - 10.971782684326172, - ] - expected_eval_loss = [10.959012985229492] - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - use_internel_loss_scale=True, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_gradient_accumulation_mixed_precision(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006077766418457, - 11.047025680541992, - 11.027434349060059, - 11.0156831741333, - 11.060973167419434, - 10.971841812133789, - ] - expected_all_finites = [True, True] - expected_eval_loss = [10.95903205871582] - actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=4, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py b/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py deleted file mode 100644 index 540f39b797bdb..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import unittest - -import torch -import torch.nn as nn -from numpy.testing import assert_allclose -from onnxruntime_test_ort_trainer import map_optimizer_attributes, ort_trainer_learning_rate_description -from onnxruntime_test_training_unittest_utils import process_dropout - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer - - -class TestTrainingDropout(unittest.TestCase): - def setUp(self): - torch.manual_seed(1) - onnxruntime.set_seed(1) - - @unittest.skip( - "Temporarily disable this test. The graph below will trigger ORT to " - "sort backward graph before forward graph which gives incorrect result. " - "https://github.com/microsoft/onnxruntime/issues/16801" - ) - def test_training_and_eval_dropout(self): - class TwoDropoutNet(nn.Module): - def __init__(self, drop_prb_1, drop_prb_2, dim_size): - super().__init__() - self.drop_1 = nn.Dropout(drop_prb_1) - self.drop_2 = nn.Dropout(drop_prb_2) - self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32)) - - def forward(self, x): - x = x + self.weight_1 - x = self.drop_1(x) - x = self.drop_2(x) - output = x - return output[0] - - dim_size = 3 - device = torch.device("cuda", 0) - # This will drop all values, therefore expecting all 0 in output tensor - model = TwoDropoutNet(0.999, 0.999, dim_size) - input_desc = IODescription("input", [dim_size], torch.float32) - output_desc = IODescription("output", [], torch.float32) - model_desc = ModelDescription([input_desc], [output_desc]) - lr_desc = ort_trainer_learning_rate_description() - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes, - lr_desc, - device, - postprocess_model=process_dropout, - world_rank=0, - world_size=1, - ) - input = torch.ones(dim_size, dtype=torch.float32).to(device) - expected_training_output = [0.0] - expected_eval_output = [1.0] - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [input, learning_rate] - train_output = model.train_step(*input_args) - - rtol = 1e-04 - assert_allclose( - expected_training_output, - train_output.item(), - rtol=rtol, - err_msg="dropout training loss mismatch", - ) - - eval_output = model.eval_step(input) - assert_allclose( - expected_eval_output, - eval_output.item(), - rtol=rtol, - err_msg="dropout eval loss mismatch", - ) - - # Do another train step to make sure it's using original ratios - train_output_2 = model.train_step(*input_args) - assert_allclose( - expected_training_output, - train_output_2.item(), - rtol=rtol, - err_msg="dropout training loss 2 mismatch", - ) - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py b/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py deleted file mode 100644 index 3d3feca06a99b..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np -from onnx import numpy_helper - - -def get_node_index(model, node): - i = 0 - while i < len(model.graph.node): - if model.graph.node[i] == node: - break - i += 1 - return i if i < len(model.graph.node) else None - - -def add_const(model, name, output, t_value=None, f_value=None): - const_node = model.graph.node.add() - const_node.op_type = "Constant" - const_node.name = name - const_node.output.extend([output]) - attr = const_node.attribute.add() - attr.name = "value" - if t_value is not None: - attr.type = 4 - attr.t.CopyFrom(t_value) - else: - attr.type = 1 - attr.f = f_value - return const_node - - -def process_dropout(model): - dropouts = [] - index = 0 - for node in model.graph.node: - if node.op_type == "Dropout": - new_dropout = model.graph.node.add() - new_dropout.op_type = "TrainableDropout" - new_dropout.name = "TrainableDropout_%d" % index - # make ratio node - ratio = np.asarray([node.attribute[0].f], dtype=np.float32) - print(ratio.shape) - ratio_value = numpy_helper.from_array(ratio) - ratio_node = add_const( - model, - "dropout_node_ratio_%d" % index, - "dropout_node_ratio_%d" % index, - t_value=ratio_value, - ) - print(ratio_node) - new_dropout.input.extend([node.input[0], ratio_node.output[0]]) - new_dropout.output.extend(node.output) - dropouts.append(get_node_index(model, node)) - index += 1 - dropouts.sort(reverse=True) - for d in dropouts: - del model.graph.node[d] - model.opset_import[0].version = 10 diff --git a/onnxruntime/test/python/transformers/conformer_model_generator.py b/onnxruntime/test/python/transformers/conformer_model_generator.py new file mode 100644 index 0000000000000..71e4f2b63cf4f --- /dev/null +++ b/onnxruntime/test/python/transformers/conformer_model_generator.py @@ -0,0 +1,543 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from typing import List + +import numpy as np +import onnx +from bert_model_generator import float_tensor +from onnx import TensorProto, helper, numpy_helper + + +# Adapted from bert_model_generator.py +def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = ( + [np.random.uniform(low, high) for _ in range(total_elements)] + if random + else [0.0] * total_elements + if zeros + else [1.0] * total_elements + ) + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights), weights + + +def create_conformer_attention( + hidden_size=512, + num_heads=8, + epsilon=0.000009999999747378752, + add_before_layernorm=False, + fused=False, +): + # Get head size and ensure head size is an integer + assert hidden_size % num_heads == 0 + head_size = hidden_size // num_heads + + # Construct input and output nodes + inputs = [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("inp_cache_k", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]), + helper.make_tensor_value_info("inp_cache_v", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]), + ] + outputs = [ + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 8, hidden_size]), + helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("oup_cache_k", TensorProto.FLOAT, ["batch_size", 8, 80, 64]), + helper.make_tensor_value_info("oup_cache_v", TensorProto.FLOAT, ["batch_size", 8, 80, 64]), + ] + nodes = [] + + # Create layernorm (Add + LayerNorm or SkipLayerNorm) + if add_before_layernorm: + nodes.extend( + [ + helper.make_node( + "Add", ["input_0", "input_1"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm" + ), + helper.make_node( + "LayerNormalization", + ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"], + ["layernorm_add_output_to_matmul"], + "layernorm", + epsilon=epsilon, + ), + ] + ) + else: + nodes.append( + helper.make_node( + "SkipLayerNormalization", + ["input_0", "input_1", "layernorm_weight", "layernorm_bias"], + ["layernorm_add_output_to_matmul", "", "", "layernorm_add_output_to_skiplayernorm"], + "skiplayernorm", + domain="com.microsoft", + epsilon=epsilon, + ) + ) + + if fused: + fused_q_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "q_weight"], + ["q_matmul_output"], + "q_path_matmul", + ), + helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), + helper.make_node( + "Reshape", ["q_add_output", "k_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d", allowzero=0 + ), + helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Div", + ["q_4d_bnsh", "q_scale"], + ["q_div_output"], + "q_div_by_sqrt_head_size", + ), + ] + nodes.extend(fused_q_nodes) + nodes.extend( + [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "k_weight"], + ["k_matmul_output"], + "k_path_matmul", + ), + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "v_weight"], + ["v_matmul_output"], + "v_path_matmul", + ), + helper.make_node( + "Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb", allowzero=0 + ), + helper.make_node( + "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "MatMul", + ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"], + ["pos_matmul"], + "pos_embed_matmul", + ), + helper.make_node( + "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "Reshape", + ["transpose_pos_matmul", "position_embed_output"], + ["reshape_position_emb"], + "final_reshape_pos_emb", + allowzero=0, + ), + helper.make_node( + "MultiHeadAttention", + [ + "q_matmul_output", + "k_matmul_output", + "v_matmul_output", + "Attention_0_qkv_bias", + "", + "reshape_position_emb", + "gather_past_k_output", + "gather_past_v_output", + ], + ["attn_output", "oup_cache_k", "oup_cache_v"], + "Attention_0", + domain="com.microsoft", + num_heads=num_heads, + ), + ] + ) + # Create nodes used with qkv concats, reshapes, and transposes + nodes.extend( + [ + helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0), + helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), + helper.make_node( + "Mul", + ["gather_0_output", "num_heads_int"], + ["mul_attn_heads_output"], + "mul_num_heads", + ), + helper.make_node( + "Unsqueeze", + ["mul_attn_heads_output", "unsqueeze_axes_input"], + ["unsqueeze_position_embed"], + "unsqueeze_position_embed", + ), + helper.make_node( + "Concat", + ["unsqueeze_position_embed", "neg_one", "head_size"], + ["position_embed_output"], + "position_embed_concat_output", + axis=0, + ), + helper.make_node( + "Unsqueeze", + ["gather_0_output", "unsqueeze_axes_input"], + ["unsqueeze_attn_heads_output"], + "unsqueeze_num_heads", + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["k_attn_heads_output"], + "k_num_heads", + axis=0, + ), + ] + ) + + nodes.extend( + [ + helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0), + helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0), + ] + ) + else: + # Create nodes for Q/K/V paths + q_nodes = [ + helper.make_node( + "MatMul", ["layernorm_add_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" + ), + helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), + helper.make_node("Reshape", ["q_add_output", "q_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d"), + helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Div", + ["q_4d_bnsh", "q_scale"], + ["q_div_output"], + "q_div_by_sqrt_head_size", + ), + ] + k_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "k_weight"], + ["k_matmul_output"], + "k_path_matmul", + ), + helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), + helper.make_node("Reshape", ["k_add_output", "k_attn_heads_output"], ["k_4d_bsnh"], "k_reshape_to_4d"), + helper.make_node("Transpose", ["k_4d_bsnh"], ["k_4d_bnsh"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Concat", + ["gather_past_k_output", "k_4d_bnsh"], + ["oup_cache_k"], + "concat_past_k_and_curr_k", + axis=2, + ), + helper.make_node( + "Transpose", + ["oup_cache_k"], + ["k_output_transpose"], + "k_transpose_last_two_dims", + perm=[0, 1, 3, 2], + ), + ] + v_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "v_weight"], + ["v_matmul_output"], + "v_path_matmul", + ), + helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), + helper.make_node("Reshape", ["v_add_output", "v_attn_heads_output"], ["v_4d_bsnh"], "v_reshape_to_4d"), + helper.make_node("Transpose", ["v_4d_bsnh"], ["v_4d_bnsh"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Concat", + ["gather_past_v_output", "v_4d_bnsh"], + ["oup_cache_v"], + "concat_past_v_and_curr_v", + axis=2, + ), + ] + pos_embed = [ + helper.make_node("Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb"), + helper.make_node( + "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "MatMul", + ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"], + ["pos_matmul"], + "pos_embed_matmul", + ), + helper.make_node( + "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "Reshape", + ["transpose_pos_matmul", "position_embed_output"], + ["reshape_position_emb"], + "final_reshape_pos_emb", + ), + ] + nodes.extend(q_nodes) + nodes.extend(k_nodes) + nodes.extend(v_nodes) + nodes.extend(pos_embed) + + # Create nodes used with qkv concats, reshapes, and transposes + nodes.extend( + [ + helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0), + helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), + helper.make_node( + "Mul", + ["gather_0_output", "num_heads_int"], + ["mul_attn_heads_output"], + "mul_num_heads", + ), + helper.make_node( + "Unsqueeze", + ["mul_attn_heads_output", "unsqueeze_axes_input"], + ["unsqueeze_position_embed"], + "unsqueeze_position_embed", + ), + helper.make_node( + "Concat", + ["unsqueeze_position_embed", "neg_one", "head_size"], + ["position_embed_output"], + "position_embed_concat_output", + axis=0, + ), + helper.make_node( + "Unsqueeze", + ["gather_0_output", "unsqueeze_axes_input"], + ["unsqueeze_attn_heads_output"], + "unsqueeze_num_heads", + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["q_attn_heads_output"], + "q_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["k_attn_heads_output"], + "k_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["v_attn_heads_output"], + "v_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size"], + ["bsd_format"], + axis=0, + ), + helper.make_node( + "Constant", + inputs=[], + outputs=["q_bsnh_reshape"], + value=numpy_helper.from_array( + np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" + ), + ), + ] + ) + + nodes.extend( + [ + helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0), + helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0), + ] + ) + + # Compute Q x K' + nodes.extend( + [ + helper.make_node( + "MatMul", + [ + "q_div_output", + "k_output_transpose", + ], + ["qk_output"], + "matmul_qk", + ) + ] + ) + + # Create nodes for computing softmax(Q x K') x V + nodes.extend( + [ + helper.make_node( + "Add", + [ + "qk_output", + "reshape_position_emb", + ], + ["add_qk_output"], + "add_qk", + ), + helper.make_node( + "Softmax", + ["add_qk_output"], + ["softmax_output"], + "softmax_qk", + axis=2, + ), + helper.make_node( + "MatMul", + ["softmax_output", "oup_cache_v"], + ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], + "matmul_qkv", + ), + helper.make_node( + "Transpose", + ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], + ["qkv_bsnh"], + "transpose_bnsh_to_bsnh", + perm=[0, 2, 1, 3], + ), + helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), + ] + ) + + # Create final nodes to conclude attention + nodes.append( + helper.make_node( + "MatMul", + ["attn_output", "matmul_after_attn_initializer"], + ["matmul_after_attn_output"], + "matmul_after_attn", + ), + ) + if not fused: + next_sln_inputs = [ + "layernorm_add_output_to_skiplayernorm", + "add_after_attn_output", + "layernorm_weight", + "layernorm_bias", + ] + nodes.extend( + [ + helper.make_node( + "Add", + ["add_after_attn_initializer", "matmul_after_attn_output"], + ["add_after_attn_output"], + "add_after_attn", + ), + helper.make_node( + "SkipLayerNormalization", + next_sln_inputs, + ["output_0", "", "", "output_1"], + "next_skiplayernorm", + domain="com.microsoft", + epsilon=epsilon, + ), + ] + ) + else: + next_sln_inputs = [ + "matmul_after_attn_output", + "layernorm_add_output_to_skiplayernorm", + "layernorm_weight", + "layernorm_bias", + "add_after_attn_initializer", + ] + nodes.append( + helper.make_node( + "SkipLayerNormalization", + next_sln_inputs, + ["output_0", "", "", "output_1"], + "SkipLayerNorm_AddBias_0", + domain="com.microsoft", + epsilon=epsilon, + ) + ) + + # Create initializers + v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) + v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) + q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) + q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) + k_weight, k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) + k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size]) + + qkv_bias = helper.make_tensor( + "Attention_0_qkv_bias", + TensorProto.FLOAT, + [3 * hidden_size], + q_bias_data + k_bias_data + v_bias_data, + ) + initializers = [ + float_tensor("layernorm_weight", [hidden_size]), + float_tensor("layernorm_bias", [hidden_size]), + float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), + float_tensor("add_after_attn_initializer", [hidden_size]), + ] + + # Add Q/K/V weight tensors as initializers + if fused: + initializers.extend([q_weight, k_weight, v_weight]) + initializers.extend([q_bias]) + initializers.append(qkv_bias) + initializers.extend( + [ + numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), + numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), + numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), + numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), + numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), + numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), + numpy_helper.from_array(np.array([0, 0, num_heads, head_size], dtype="int64"), name="q_bsnh_reshape"), + ] + ) + else: + initializers.extend([q_weight, k_weight, v_weight]) + + initializers.extend([q_bias, k_bias, v_bias]) + + initializers.extend( + [ + numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), + numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), + numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), + numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), + numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), + numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), + numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), + numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), + numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), + ] + ) + + # Construct graph + graph = helper.make_graph(nodes, "conformer_self_mha_graph", inputs, outputs, initializers, doc_string="conformer") + opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) + return helper.make_model(graph, opset_imports=(opsetid,)) + + +if __name__ == "__main__": + np.random.seed(2) + num_heads = 8 + hidden_size = 512 + + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size) + onnx.save(model, "conformer_self_mha.onnx") + + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True) + onnx.save(model, "./test_data/models/conformer/conformer_self_mha_fused.onnx") diff --git a/onnxruntime/test/python/transformers/test_conformer.py b/onnxruntime/test/python/transformers/test_conformer.py new file mode 100644 index 0000000000000..471ba9756bcf8 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_conformer.py @@ -0,0 +1,69 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import unittest + +import onnx +from conformer_model_generator import create_conformer_attention +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +class TestFusion(unittest.TestCase): + def verify_fusion(self, optimized_model, expected_model_filename): + optimized_model.topological_sort(is_deterministic=True) + + expected_model_path = os.path.join( + os.path.dirname(__file__), "test_data", "models", "conformer", expected_model_filename + ) + print("Expected model path = ", expected_model_path) + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + nodes = optimized_model.model.graph.node + self.assertEqual(len(nodes), len(expected_model.model.graph.node)) + + for i in range(len(nodes)): + self.assertEqual(nodes[i], expected_model.model.graph.node[i]) + + for expected_initializer in expected_model.model.graph.initializer: + print("Expected initializer initial = ", expected_initializer.name) + self.assertTrue( + OnnxModel.has_same_value( + optimized_model.get_initializer(expected_initializer.name), expected_initializer + ) + ) + + def test_ct_mha_fusion(self): + num_heads = 8 + hidden_size = 512 + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False) + dir = "." + model_path = os.path.join(dir, "conformer_self_mha.onnx") + onnx.save(model, model_path) + options = FusionOptions("conformer") + optimized_model = optimize_model( + model_path, + model_type="conformer", + num_heads=num_heads, + hidden_size=hidden_size, + optimization_options=options, + ) + os.remove(model_path) + self.verify_fusion(optimized_model, "conformer_self_mha_fused.onnx") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx new file mode 100644 index 0000000000000..9d882751db265 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 99f62ffdb9f53..8a839875de2a2 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -183,7 +183,9 @@ def create_multihead_attention_graph(config): return model.SerializeToString() -def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSNH, share_buffer=True): +def create_group_query_attention_graph_prompt( + config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 +): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length nodes = [ @@ -202,6 +204,7 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -297,6 +300,26 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN config.head_size, ], ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), ] graph = helper.make_graph( @@ -310,7 +333,9 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN return model.SerializeToString() -def create_group_query_attention_graph_past(config, past_kv_format=Formats.BSNH, share_buffer=True): +def create_group_query_attention_graph_past( + config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 +): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length @@ -331,6 +356,7 @@ def create_group_query_attention_graph_past(config, past_kv_format=Formats.BSNH, "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -636,8 +662,12 @@ def mha_func(q, k, v, config): return output -def gqa_prompt_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): - onnx_model_str = create_group_query_attention_graph_prompt(config, past_kv_format, share_buffer) +def gqa_prompt_func( + q, k, v, config, new_k, new_v, seqlens_k=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True +): + onnx_model_str = create_group_query_attention_graph_prompt( + config, past_kv_format, share_buffer, local_window_size=window_size + ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None @@ -706,8 +736,12 @@ def gqa_prompt_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_forma return output, present_k, present_v -def gqa_past_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): - onnx_model_str = create_group_query_attention_graph_past(config, past_kv_format, share_buffer) +def gqa_past_func( + q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1 +): + onnx_model_str = create_group_query_attention_graph_past( + config, past_kv_format, share_buffer, local_window_size=window_size + ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() @@ -796,6 +830,28 @@ def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_paddi return col_idx > row_idx + sk - sq +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + def attention_ref( q, k, @@ -805,6 +861,7 @@ def attention_ref( dropout_p=0.0, dropout_mask=None, causal=False, + window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, ): @@ -817,6 +874,8 @@ def attention_ref( key_padding_mask: (batch_size, seqlen_k) dropout_p: float dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16. reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) @@ -826,6 +885,8 @@ def attention_ref( output: (batch_size, seqlen_q, nheads, head_dim) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ + if causal: + window_size = (window_size[0], 0) dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() @@ -839,12 +900,24 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - if causal: - causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device) - scores.masked_fill_(causal_mask, float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) - if causal: # Some rows are completely masked out so we fill them with zero instead of NaN - attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) @@ -853,7 +926,6 @@ def attention_ref( output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) @@ -957,6 +1029,8 @@ def parity_check_mha( def parity_check_gqa_prompt( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1007,6 +1081,15 @@ def parity_check_gqa_prompt( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = k.clone() v_cache_ref = v.clone() @@ -1033,14 +1116,18 @@ def parity_check_gqa_prompt( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) + out, present_k, present_v = gqa_prompt_func( + q, k, v, config, new_k, new_v, cache_seqlens, left_window_size, past_format, True + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1052,6 +1139,10 @@ def parity_check_gqa_prompt( # Compare results print( "KV-buffer", + " causal:", + causal, + " local:", + local, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1080,6 +1171,8 @@ def parity_check_gqa_prompt( def parity_check_gqa_prompt_no_buff( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1112,6 +1205,15 @@ def parity_check_gqa_prompt_no_buff( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = new_k.clone() v_cache_ref = new_v.clone() @@ -1132,14 +1234,18 @@ def parity_check_gqa_prompt_no_buff( new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func(q, None, None, config, new_k, new_v, cache_seqlens, past_format, False) + out, present_k, present_v = gqa_prompt_func( + q, None, None, config, new_k, new_v, cache_seqlens, left_window_size, past_format, False + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1179,6 +1285,8 @@ def parity_check_gqa_prompt_no_buff( def parity_check_gqa_past( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1228,6 +1336,14 @@ def parity_check_gqa_past( dtype=torch.float16, requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) # Pytorch to compare k_cache_ref = k.clone() @@ -1253,14 +1369,18 @@ def parity_check_gqa_past( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) + out, present_k, present_v = gqa_past_func( + q, k, v, config, new_k, new_v, cache_seqlens, past_format, True, left_window_size + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1274,6 +1394,10 @@ def parity_check_gqa_past( "KV-buffer", "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", + " causal:", + causal, + " local:", + local, " B:", config.batch_size, " S:", @@ -1300,6 +1424,8 @@ def parity_check_gqa_past( def parity_check_gqa_past_no_buff( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1351,6 +1477,15 @@ def parity_check_gqa_past_no_buff( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = k.clone() v_cache_ref = v.clone() @@ -1378,14 +1513,18 @@ def parity_check_gqa_past_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, False) + out, present_k, present_v = gqa_past_func( + q, k, v, config, new_k, new_v, cache_seqlens, past_format, False, window_size=left_window_size + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1401,142 +1540,10 @@ def parity_check_gqa_past_no_buff( # Compare results print( "NO buff", - "past kv format:", - "BSNH" if past_format == Formats.BSNH else "BNSH", - " B:", - config.batch_size, - " S:", - config.sequence_length, - " kv S:", - config.kv_sequence_length, - " N:", - config.num_heads, - " kv N:", - config.kv_num_heads, - " h:", - config.head_size, - " Mean Error:", - numpy.mean(numpy.abs(out - out_ref)), - numpy.allclose( - out, - out_ref, - rtol=rtol, - atol=atol, - equal_nan=True, - ), - ) - - -def parity_check_gqa_past_no_buff_no_mask( - config, - past_format=Formats.BSNH, - rtol=1e-3, - atol=1e-3, -): - q = torch.randn( - config.batch_size, - config.sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - # Pytorch to compare - k_cache_ref = k.clone() - v_cache_ref = v.clone() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - k_cache_ref = torch.cat((k_cache_ref, new_k), 1) - v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - key_padding_mask = None - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, past_format, False) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - # Make sure past-present buffer updating correctly - if past_format == Formats.BSNH: - assert numpy.allclose( - present_k, - k_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - assert numpy.allclose( - present_v, - v_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - else: - assert numpy.allclose( - present_k, - k_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - assert numpy.allclose( - present_v, - v_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - - # Compare results - print( - "Unbuffered", + " causal:", + causal, + " local:", + local, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1663,10 +1670,11 @@ def test_gqa_no_past(self): for sq, skv in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - parity_check_gqa_prompt(config, past_format=past_kv_format) - parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format) + for local in [False, True]: + for past_kv_format in [Formats.BNSH]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt(config, local=local, past_format=past_kv_format) + parity_check_gqa_prompt_no_buff(config, local=local, past_format=past_kv_format) def test_gqa_past(self): if not torch.cuda.is_available(): @@ -1725,24 +1733,25 @@ def test_gqa_past(self): for s, s2 in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for local in [False, True]: + for past_kv_format in [Formats.BNSH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) if __name__ == "__main__": unittest.main() - # test_gqa = TestGQA() - # test_gqa.test_gqa_past() diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py new file mode 100644 index 0000000000000..72ca5d9975c05 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_parity_moe.py @@ -0,0 +1,431 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +import time +import unittest + +import numpy +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from onnx import TensorProto, helper + +import onnxruntime + +torch.manual_seed(42) +numpy.random.seed(42) + + +ORT_DTYPE = TensorProto.FLOAT16 +NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 +THRESHOLD = 3e-2 + + +def value_string_of(numpy_array): + arr = numpy_array.flatten() + lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] + return "{\n " + "f,\n ".join(lines) + "f}" + + +def print_tensor(name, numpy_array): + print(f"const std::vector {name} = {value_string_of(numpy_array)};") + + +def create_moe_onnx_graph( + num_rows, + num_experts, + hidden_size, + inter_size, + fc1_experts_weights, + fc2_experts_weights, + fc1_experts_bias, + fc2_experts_bias, +): + nodes = [ + helper.make_node( + "MoE", + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc2_experts_weights", + "fc1_experts_bias", + "fc2_experts_bias", + ], + ["output"], + "MoE_0", + k=1, + activation_type="gelu", + domain="com.microsoft", + ), + ] + + fc1_shape = [num_experts, hidden_size, inter_size] + fc2_shape = [num_experts, inter_size, hidden_size] + + torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + ORT_DTYPE, + fc1_shape, + fc1_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weights", + ORT_DTYPE, + fc2_shape, + fc2_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + + fc1_bias_shape = [num_experts, inter_size] + fc2_bias_shape = [num_experts, hidden_size] + initializers.extend( + [ + helper.make_tensor( + "fc1_experts_bias", + ORT_DTYPE, + fc1_bias_shape, + fc1_experts_bias.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_bias", + ORT_DTYPE, + fc2_bias_shape, + fc2_experts_bias.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + ORT_DTYPE, + [num_rows, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def get_activation_fn(activation): + if activation == "relu": + return nn.ReLU + elif activation == "gelu": + return nn.GELU + else: + raise NotImplementedError + + +class MoEGate(nn.Module): + def __init__(self, num_experts, in_features): + super().__init__() + self.wg_reduction = torch.nn.Linear(in_features, 16, bias=False) + + wg = torch.empty(num_experts, 16) + torch.nn.init.orthogonal_(wg, gain=0.32) + self.register_parameter("wg", torch.nn.Parameter(wg)) + + def forward(self, input): + input = self.wg_reduction(input) + with torch.no_grad(): + wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True) + self.wg.mul_(1.5 / wg_norm) + logits = self._cosine(input, self.wg) + return logits + + def _cosine(self, mat1, mat2, eps=1e-4): + assert mat1.dim() == 2 + assert mat2.dim() == 2 + + mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps) + return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1) + + +class MoERuntimeExperts(nn.Module): + def __init__( + self, + num_experts, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + bias=True, + chunk_size=-1, + ): + super().__init__() + # assert bias is False, "Current bias is not supported" + assert drop == 0.0, "Current drop is not supported" + assert chunk_size == -1, "Current chunk is not supported" + + self.weight1 = nn.Parameter(torch.rand(num_experts, in_features, hidden_features)) + self.weight2 = nn.Parameter(torch.rand(num_experts, hidden_features, out_features)) + + self.bias1 = nn.Parameter(torch.rand(num_experts, hidden_features)) if bias else None + self.bias2 = nn.Parameter(torch.rand(num_experts, in_features)) if bias else None + + self.act = act_layer() + + def forward(self, x, indices_s): + x = x.unsqueeze(1) + x = self.bmm(x, self.weight1, indices_s) + if self.bias1 is not None: + x = x + self.bias1[indices_s].unsqueeze(1) # S x hidden_features + x = self.act(x) + x = self.bmm(x, self.weight2, indices_s) + if self.bias2 is not None: + x = x + self.bias2[indices_s].unsqueeze(1) # S x 1 x in_features + return x + + def bmm(self, x, weight, indices_s): + x = torch.bmm(x, weight[indices_s]) # S x 1 x hidden_features + return x + + +class MoE(nn.Module): + def __init__( + self, + batch_size, + num_rows, + num_experts, + in_features, + hidden_features=None, + out_features=None, + eval_capacity=-1, + activation="gelu", + ): + super().__init__() + self.num_experts = num_experts + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.eval_capacity = eval_capacity # -1 means we route all tokens + + self.gate = MoEGate(num_experts=num_experts, in_features=in_features) + self.moe_experts = MoERuntimeExperts( + num_experts=num_experts, + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + act_layer=get_activation_fn(activation), + bias=True, + ) + + self.moe_onnx_graph = create_moe_onnx_graph( + batch_size * num_rows, + num_experts, + in_features, + hidden_features, + self.moe_experts.weight1, + self.moe_experts.weight2, + self.moe_experts.bias1, + self.moe_experts.bias2, + ) + + self.ort_sess = self.create_ort_session() + + self.torch_input = torch.randn(batch_size, num_rows, in_features) + + def create_ort_session(self): + from onnxruntime import InferenceSession, SessionOptions + + sess_options = SessionOptions() + + cuda_providers = ["CUDAExecutionProvider"] + if cuda_providers[0] not in onnxruntime.get_available_providers(): + return None + + sess_options.log_severity_level = 2 + ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) + + return ort_session + + def ort_run_with_iobinding(self, ort_inputs, repeat=1000): + iobinding = self.ort_sess.io_binding() + device_id = torch.cuda.current_device() + + iobinding.bind_input( + name="input", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["input"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), + ) + iobinding.bind_input( + name="router_probs", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["router_probs"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( + ort_inputs["router_probs"], "cuda", device_id + ).data_ptr(), + ) + + iobinding.synchronize_inputs() + + iobinding.bind_output( + name="output", + device_type="cuda", + device_id=device_id, + element_type=NP_TYPE, + shape=ort_inputs["input"].shape, + buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( + numpy.zeros(ort_inputs["input"].shape), "cuda", device_id + ).data_ptr(), + ) + iobinding.synchronize_outputs() + + s = time.time() + for _ in range(repeat): + self.ort_sess.run_with_iobinding(iobinding) + e = time.time() + print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + + def torch_forward(self): + x = self.torch_input + + b, t, c = x.shape + x = x.reshape(-1, c) + logits = self.gate(x) + gates = torch.nn.functional.softmax(logits, dim=1) + ret = torch.max(gates, dim=1) + indices_s = ret.indices # dim: [bs], the index of the expert with highest softmax value + scores = ret.values.unsqueeze(-1).unsqueeze(-1) # S + x = self.moe_experts(x, indices_s) + + x = x * scores + x = x.reshape(b * t, c) + + return x, torch.sum(x) + + def onnx_forward(self, iobinding=False): + x = self.torch_input + + _, _, c = x.shape + y = x.reshape(-1, c) + logits = self.gate(y) + + ort_inputs = { + "input": numpy.ascontiguousarray(y.detach().numpy().astype(NP_TYPE)), + "router_probs": numpy.ascontiguousarray(logits.detach().numpy().astype(NP_TYPE)), + } + + ort_output = None + if self.ort_sess is not None: + if not iobinding: + ort_output = self.ort_sess.run(None, ort_inputs) + else: + self.ort_run_with_iobinding(ort_inputs) + return None + + # print_tensor("input", ort_inputs["input"]) + # print_tensor("router_probs", ort_inputs["router_probs"]) + # print_tensor("fc1_experts_weights", self.moe_experts.weight1.detach().numpy()) + # print_tensor("fc2_experts_weights", self.moe_experts.weight2.detach().numpy()) + # print_tensor("fc1_experts_bias", self.moe_experts.bias1.detach().numpy()) + # print_tensor("fc2_experts_bias", self.moe_experts.bias2.detach().numpy()) + # print_tensor("output", ort_output[0]) + + return ort_output + + def parity_check(self): + torch_out = self.torch_forward() + ort_out = self.onnx_forward() + if ort_out is not None: + # print("max diff", numpy.max(numpy.abs(torch_out[0].detach().numpy() - ort_out[0]))) + assert numpy.allclose(torch_out[0].detach().numpy(), ort_out[0], rtol=THRESHOLD, atol=THRESHOLD) + + def benchmark(self): + self.onnx_forward(iobinding=True) + + +class TestMoE(unittest.TestCase): + def test_moe_small(self): + rt = MoE( + batch_size=2, + num_rows=8, + num_experts=4, + in_features=16, + hidden_features=32, + out_features=16, + ) + rt.parity_check() + + @pytest.mark.slow + def test_moe_large(self): + for batch_size in [1, 8]: + for num_rows in [16, 64]: + for num_experts in [16, 64]: + for in_features in [256]: + for hidden_features in [512]: + print( + f"batch_size={batch_size}, num_rows={num_rows}, num_experts={num_experts}, in_features={in_features}, hidden_features={hidden_features}" + ) + rt = MoE( + batch_size=batch_size, + num_rows=num_rows, + num_experts=num_experts, + in_features=in_features, + hidden_features=hidden_features, + out_features=in_features, + ) + rt.parity_check() + + @pytest.mark.slow + def test_moe_benchmark(self): + for batch_size in [32, 64]: + for num_rows in [128, 512]: + for num_experts in [64, 128]: + for in_features in [256, 512]: + for hidden_features in [1024, 2048]: + print( + f"batch_size={batch_size}, num_rows={num_rows}, num_experts={num_experts}, in_features={in_features}, hidden_features={hidden_features}" + ) + rt = MoE( + batch_size=batch_size, + num_rows=num_rows, + num_experts=num_experts, + in_features=in_features, + hidden_features=hidden_features, + out_features=in_features, + ) + rt.benchmark() + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py index b17ae5f69aff5..cf8128e0eebcf 100644 --- a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py +++ b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py @@ -261,14 +261,15 @@ def get_eps(self): eps = ["CPUExecutionProvider", "CUDAExecutionProvider"] return list(filter(lambda ep: ep in ort.get_available_providers(), eps)) - def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh): + def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh, transposed=False): eps = self.get_eps() for ep in eps: sess = ort.InferenceSession(onnx_graph, providers=[ep]) output_ort = sess.run(None, inputs_ort)[0] - output_ort = output_ort.reshape( - (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size) - ) + if not transposed: + output_ort = output_ort.reshape( + (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size) + ) # Compare outputs as BxSxNxH self.assertTrue(np.allclose(expected_output_bsnh, output_ort)) @@ -445,6 +446,44 @@ def test_hf_token_rotary_one_pos_id(self): # Compare outputs as BxSxNxH self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + # Bonus test: Prompt step, interleaved = false, pos ids shape = (1), transposed + def test_hf_prompt_rotary_one_pos_id_transposed(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([0]) + onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bnsh.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare outputs as BxNxSxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True) + + # Bonus test: Token generation step, interleaved = false, pos ids shape = (1), transposed + def test_hf_token_rotary_one_pos_id_transposed(self): + x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxSxNxH + + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([2]) + onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bnsh.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Set tranposed=True to compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/transform/gh_issue_18338.onnx b/onnxruntime/test/testdata/transform/gh_issue_18338.onnx new file mode 100644 index 0000000000000..afb499a347ec7 Binary files /dev/null and b/onnxruntime/test/testdata/transform/gh_issue_18338.onnx differ diff --git a/onnxruntime/test/testdata/transform/gh_issue_18338.py b/onnxruntime/test/testdata/transform/gh_issue_18338.py new file mode 100644 index 0000000000000..dc5446ac56c09 --- /dev/null +++ b/onnxruntime/test/testdata/transform/gh_issue_18338.py @@ -0,0 +1,859 @@ +import google.protobuf.text_format +import onnx +from numpy import array, float16 + +import onnxruntime as ort + +# Run n times +N = 1 + +onnx_model_text = """ +ir_version: 8 +producer_name: "pytorch" +producer_version: "2.2.0" +graph { + node { + output: "_val_1" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value_ints" + ints: -1 + type: INTS + } + doc_string: "" + } + node { + input: "input_0" + input: "_val_1" + output: "_val_2" + name: "Reshape_1" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + doc_string: "" + } + node { + input: "_val_2" + output: "_val_3" + name: "_aten_linalg_vector_norm_no_dim_onnx_2" + op_type: "_aten_linalg_vector_norm_no_dim_onnx" + attribute { + name: "keepdim" + i: 0 + type: INT + } + attribute { + name: "ord" + f: 2.0 + type: FLOAT + } + doc_string: "" + domain: "pkg.onnxscript.torch_lib" + } + name: "main_graph" + input { + name: "input_0" + type { + tensor_type { + elem_type: 10 + shape { + } + } + } + } + output { + name: "_val_3" + type { + tensor_type { + elem_type: 10 + shape { + } + } + } + } + value_info { + name: "_val_1" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + value_info { + name: "_val_2" + type { + tensor_type { + elem_type: 10 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + domain: "pkg.onnxscript.torch_lib" + version: 1 +} +opset_import { + domain: "" + version: 18 +} +opset_import { + domain: "pkg.onnxscript.torch_lib.common" + version: 1 +} +functions { + name: "_aten_linalg_vector_norm_no_dim_onnx" + input: "self" + output: "result_29" + attribute: "ord" + attribute: "keepdim" + node { + input: "self" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "self_rank" + name: "n1" + op_type: "Size" + domain: "" + } + node { + output: "int64_0" + name: "n2" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + int64_data: 0 + name: "int64_0" + } + type: TENSOR + } + domain: "" + } + node { + input: "int64_0" + input: "self_rank" + output: "int64_0_cast" + name: "n3" + op_type: "CastLike" + domain: "" + } + node { + input: "self_rank" + input: "int64_0_cast" + output: "cond" + name: "n4" + op_type: "Equal" + domain: "" + } + node { + input: "cond" + output: "self_2" + name: "n5" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + output: "int64_0_1d" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: 0 + name: "int64_0_1d" + } + type: TENSOR + } + domain: "" + } + node { + input: "self" + input: "int64_0_1d" + output: "self_0" + name: "n1" + op_type: "Unsqueeze" + domain: "" + } + name: "thenGraph_4" + output { + name: "self_0" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "self" + output: "self_1" + name: "n0" + op_type: "Identity" + domain: "" + } + name: "elseGraph_4" + output { + name: "self_1" + type { + } + } + } + type: GRAPH + } + domain: "" + } + node { + input: "self_2" + output: "self_3" + name: "n6" + op_type: "Abs" + domain: "" + } + node { + output: "ord" + name: "n7" + op_type: "Constant" + attribute { + name: "value_float" + type: FLOAT + ref_attr_name: "ord" + } + domain: "" + } + node { + input: "ord" + output: "ord_4" + name: "n8" + op_type: "Cast" + attribute { + name: "to" + i: 1 + type: INT + } + domain: "" + } + node { + input: "ord_4" + output: "cond_5" + name: "n9" + op_type: "IsInf" + attribute { + name: "detect_negative" + i: 0 + type: INT + } + attribute { + name: "detect_positive" + i: 1 + type: INT + } + domain: "" + } + node { + input: "cond_5" + output: "result_24" + name: "n10" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result" + name: "n0" + op_type: "ReduceMax" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_9" + output { + name: "result" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "ord_4" + output: "cond_6" + name: "n0" + op_type: "IsInf" + attribute { + name: "detect_negative" + i: 1 + type: INT + } + attribute { + name: "detect_positive" + i: 0 + type: INT + } + domain: "" + } + node { + input: "cond_6" + output: "result_23" + name: "n1" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_7" + name: "n0" + op_type: "ReduceMin" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_11" + output { + name: "result_7" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 0.0 + name: "const" + } + type: TENSOR + } + domain: "" + } + node { + input: "const" + input: "ord_4" + output: "const_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_cast" + output: "cond_8" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_8" + output: "result_22" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "self_bool" + name: "n0" + op_type: "Cast" + attribute { + name: "to" + i: 9 + type: INT + } + domain: "" + } + node { + input: "self_bool" + input: "self_3" + output: "self_0_1" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "self_0_1" + output: "result_9" + name: "n2" + op_type: "ReduceSum" + attribute { + name: "keepdims" + i: 0 + type: INT + } + domain: "" + } + name: "thenGraph_13" + output { + name: "result_9" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const_10" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 1.0 + name: "const_10" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_10" + input: "ord_4" + output: "const_10_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_10_cast" + output: "cond_11" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_11" + output: "result_21" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_12" + name: "n0" + op_type: "ReduceL1" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_18" + output { + name: "result_12" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const_13" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 2.0 + name: "const_13" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_13" + input: "ord_4" + output: "const_13_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_13_cast" + output: "cond_14" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_14" + output: "result_20" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_15" + name: "n0" + op_type: "ReduceL2" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_20" + output { + name: "result_15" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "ord_4" + input: "self_3" + output: "ord_float" + name: "n0" + op_type: "CastLike" + domain: "" + } + node { + input: "self_3" + input: "ord_float" + output: "self_pow" + name: "n1" + op_type: "Pow" + domain: "" + } + node { + input: "self_pow" + output: "tmp_16" + name: "n2" + op_type: "ReduceSum" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + node { + output: "const_17" + name: "n3" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 1.0 + name: "const_17" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_17" + input: "ord_float" + output: "const_17_cast" + name: "n4" + op_type: "CastLike" + domain: "" + } + node { + input: "const_17_cast" + input: "ord_float" + output: "tmp_18" + name: "n5" + op_type: "Div" + domain: "" + } + node { + input: "tmp_16" + input: "tmp_18" + output: "result_19" + name: "n6" + op_type: "Pow" + domain: "" + } + name: "elseGraph_20" + output { + name: "result_19" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_18" + output { + name: "result_20" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_13" + output { + name: "result_21" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_11" + output { + name: "result_22" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_9" + output { + name: "result_23" + type { + } + } + } + type: GRAPH + } + domain: "" + } + node { + output: "int64_0_25" + name: "n11" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + int64_data: 0 + name: "int64_0_25" + } + type: TENSOR + } + domain: "" + } + node { + input: "int64_0_25" + input: "self_rank" + output: "int64_0_25_cast" + name: "n12" + op_type: "CastLike" + domain: "" + } + node { + input: "self_rank" + input: "int64_0_25_cast" + output: "cond_26" + name: "n13" + op_type: "Equal" + domain: "" + } + node { + input: "cond_26" + output: "result_29" + name: "n14" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "result_24" + output: "result_27" + name: "n0" + op_type: "Squeeze" + domain: "" + } + name: "thenGraph_27" + output { + name: "result_27" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "result_24" + output: "result_28" + name: "n0" + op_type: "Identity" + domain: "" + } + name: "elseGraph_27" + output { + name: "result_28" + type { + } + } + } + type: GRAPH + } + domain: "" + } + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib" +} +functions { + name: "Rank" + input: "input" + output: "return_val" + node { + input: "input" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "return_val" + name: "n1" + op_type: "Size" + domain: "" + } + doc_string: "Take the rank of the input tensor." + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib.common" +} +functions { + name: "IsScalar" + input: "input" + output: "return_val" + node { + input: "input" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "tmp_0" + name: "n1" + op_type: "Size" + domain: "" + } + node { + output: "tmp_1" + name: "n2" + op_type: "Constant" + attribute { + name: "value_int" + i: 0 + type: INT + } + domain: "" + } + node { + input: "tmp_0" + input: "tmp_1" + output: "return_val" + name: "n3" + op_type: "Equal" + domain: "" + } + doc_string: "Return whether the input has rank 0, or is a scalar." + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib.common" +} + +""" + +ort_inputs = {"input_0": array(0.8965, dtype=float16)} + +# Set up the inference session +session_options = ort.SessionOptions() +session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL +onnx_model = onnx.ModelProto() +google.protobuf.text_format.Parse(onnx_model_text, onnx_model) + +# Uncomment this line to save the model to a file for examination +# onnx.save_model(onnx_model, "test_output_match_opinfo__linalg_vector_norm_cpu_float16.onnx") + +onnx.checker.check_model(onnx_model) +session = ort.InferenceSession(onnx_model.SerializeToString(), session_options, providers=("CPUExecutionProvider",)) + +# Run the model +for _ in range(N): + ort_outputs = session.run(None, ort_inputs) diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc new file mode 100644 index 0000000000000..0412000e04e1b --- /dev/null +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/optimizer/initializer.h" +#include "orttraining/core/optimizer/conv1d_replacement.h" +#include "core/graph/graph_utils.h" + +/* + In LoRA code, it will use conv1d to do projection for qkv, + while the conv1d calculation is mathematically equivalent to MatMul, and MatMul is much faster than conv1d in GPU. + The graph transformation is doing the following graph substitution: + 1. The input graph is: + conv_input conv_weight + \ / + \ / + conv1d + + 2. The output graph is as follows, + the number of MatMul is equal to attribute "group" of conv1d + conv_input conv1d.group conv_weight conv1d.group + \ / \ / + \ / Squeeze / + \ / \ / + Split Split + / / ... \ / / ... \ + / / ... \ / / ... \ + / / ... \ / / ... \ + input0 input1 ... inputN weight0 weight1 ... weightN + \ \ \ / / / + \ \ \ / / / + \ \ \ / / / + \ \ X / / + \ \ / \ / / + \ \ / X / + \ X / \ / + \ / \ / \ / + MatMul MatMul ... MatMul + \ | ... / + \ | / + \ | / +*/ +namespace onnxruntime { +bool NodeCanBeReplacedByMatmul(const Node& node) { + // If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2, + // then it can be replaced by MatMul + // Kernel_shape is 1 means it is conv1d + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) { + return false; + } + const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); + const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); + const auto* stride = graph_utils::GetNodeAttribute(node, "strides"); + const auto* group = graph_utils::GetNodeAttribute(node, "group"); + if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) { + return false; + } + if ((dilations->ints_size() && dilations->ints(0) != 1) || + (kernel_shape->ints_size() && kernel_shape->ints(0) != 1) || + (stride->ints_size() && stride->ints(0) != 1) || + group->i() >= 3) { + return false; + } + + return true; +} + +void Conv1dToMatmul(Graph& graph, Node& conv) { + // Shape of conv1d input: [batch_size, in_channels, in_length] + // Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1 + // We need to split the input into "group", and squeeze&split the weight, and then do MatMul + const std::string node_description("Conv1dReplacement"); + auto execution_provider_type = conv.GetExecutionProviderType(); + // 1. Split conv input + auto group_attr = graph_utils::GetNodeAttribute(conv, "group"); + int64_t group_num = 1; // default group is 1 from ONNX schema + if (group_attr != nullptr) { + group_num = group_attr->i(); + } + auto conv1d_input = conv.MutableInputDefs()[0]; + std::vector conv1d_input_splitted_outputs; + for (int i = 0; i < group_num; i++) { + conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("input_split_output"), nullptr)); + } + auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input}, + {conv1d_input_splitted_outputs}); + input_split.SetExecutionProviderType(execution_provider_type); + input_split.AddAttribute("axis", int64_t(1)); + auto onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + if (onnx_opset_version >= 18) { + input_split.AddAttribute("num_outputs", group_num); + } + // 2. Squeeze conv weight + auto conv1d_weight = conv.MutableInputDefs()[1]; + auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr); + auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze", + node_description, {conv1d_weight}, {weight_squeeze_output}); + if (onnx_opset_version > 12) { + // After onnx version 12, squeeze node has axes as input instead of attribute + ONNX_NAMESPACE::TensorProto initializer_proto; + initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer")); + initializer_proto.add_dims(static_cast(1)); + initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + InlinedVector initializer_proto_value{2}; + initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); + auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); + // Squeeze node doesn't have opschema here, so we need to set input args count manually + weight_squeeze.MutableInputArgsCount().resize(2); + graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); + } else { + weight_squeeze.AddAttribute("axes", std::vector{2}); + } + weight_squeeze.SetExecutionProviderType(execution_provider_type); + // 3. Split conv weight + std::vector conv1d_weight_splitted_outputs; + for (int i = 0; i < group_num; i++) { + conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("weight_split_output"), nullptr)); + } + auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, + {weight_squeeze_output}, {conv1d_weight_splitted_outputs}); + weight_split.AddAttribute("axis", int64_t(0)); + weight_split.SetExecutionProviderType(execution_provider_type); + if (onnx_opset_version >= 18) { + weight_split.AddAttribute("num_outputs", group_num); + } + // 4. Do MatMul + std::vector matmul_outputs; + for (int i = 0; i < group_num; i++) { + auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr); + matmul_outputs.push_back(matmul_output); + auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description, + {conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]}, + {matmul_output}); + matmul.SetExecutionProviderType(execution_provider_type); + } + // 5. Concat matmul outputs + auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description, + matmul_outputs, {}); + concat_node.SetExecutionProviderType(execution_provider_type); + concat_node.AddAttribute("axis", int64_t(1)); + // 6. Clean up - delted original "conv" node, its output is replaced by concat_node + graph_utils::FinalizeNodeFusion(graph, concat_node, conv); +} + +Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (!node_ptr) + continue; // node was removed + auto& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + if (NodeCanBeReplacedByMatmul(node)) { + LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name(); + Conv1dToMatmul(graph, node); + modified = true; + } + } + return Status::OK(); +} +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.h b/orttraining/orttraining/core/optimizer/conv1d_replacement.h new file mode 100644 index 0000000000000..740f13c76fd6f --- /dev/null +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +class Conv1dReplacement : public GraphTransformer { + public: + Conv1dReplacement(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("Conv1dReplacement", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 57d76577f1ba7..6193a1d10c095 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -72,6 +72,7 @@ #ifdef ENABLE_TRAINING_TORCH_INTEROP #include "orttraining/core/optimizer/pythonop_rewriter.h" #endif +#include "orttraining/core/optimizer/conv1d_replacement.h" namespace onnxruntime { namespace training { @@ -194,6 +195,7 @@ std::vector> GeneratePreTrainingTransformers( // Once we have a CPU kernel for PadAndUnflatten, we can remove the guard. transformers.emplace_back(std::make_unique(compatible_eps, config.sparse_embedding_input_names)); + transformers.emplace_back(std::make_unique(compatible_eps)); #endif } diff --git a/orttraining/orttraining/python/checkpointing_utils.py b/orttraining/orttraining/python/checkpointing_utils.py deleted file mode 100644 index 460b9982297d1..0000000000000 --- a/orttraining/orttraining/python/checkpointing_utils.py +++ /dev/null @@ -1,127 +0,0 @@ -import os - -import torch - - -def list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"): - ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)] - ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)] - ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names] - - assert len(ckpt_file_names) > 0, 'No checkpoint files found with prefix "{}" in directory {}.'.format( - checkpoint_prefix, checkpoint_dir - ) - return ckpt_file_names - - -def get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None): - SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806 - MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806 - - if is_partitioned: - filename = MULTIPLE_CHECKPOINT_FILENAME.format( - prefix=prefix, world_rank=world_rank, world_size=(world_size - 1) - ) - else: - filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix) - - return filename - - -def _split_state_dict(state_dict): - optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] - split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} - for k, v in state_dict.items(): - mode = "fp32_param" - for optim_key in optimizer_keys: - if k.startswith(optim_key): - mode = "optimizer" - break - if k.endswith("_fp16"): - mode = "fp16_param" - split_sd[mode][k] = v - return split_sd - - -class CombineZeroCheckpoint: - def __init__(self, checkpoint_files, clean_state_dict=None): - assert len(checkpoint_files) > 0, "No checkpoint files passed" - self.checkpoint_files = checkpoint_files - self.clean_state_dict = clean_state_dict - self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1 - assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files" - self.weight_shape_map = dict() - self.sharded_params = set() - - def _split_name(self, name: str): - name_split = name.split("_view_") - view_num = None - if len(name_split) > 1: - view_num = int(name_split[1]) - optimizer_key = "" - mp_suffix = "" - if name_split[0].startswith("Moment_1"): - optimizer_key = "Moment_1_" - elif name_split[0].startswith("Moment_2"): - optimizer_key = "Moment_2_" - elif name_split[0].startswith("Update_Count"): - optimizer_key = "Update_Count_" - elif name_split[0].endswith("_fp16"): - mp_suffix = "_fp16" - param_name = name_split[0] - if optimizer_key: - param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split("_fp16")[0] - return param_name, optimizer_key, view_num, mp_suffix - - def _update_weight_statistics(self, name, value): - if name not in self.weight_shape_map: - self.weight_shape_map[name] = value.size() # original shape of tensor - - def _reshape_tensor(self, key): - value = self.aggregate_state_dict[key] - weight_name, _, _, _ = self._split_name(key) - set_size = self.weight_shape_map[weight_name] - self.aggregate_state_dict[key] = value.reshape(set_size) - - def _aggregate(self, param_dict): - for k, v in param_dict.items(): - weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k) - if view_num is not None: - # parameter is sharded - param_name = optimizer_key + weight_name + mp_suffix - - if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]: - self.sharded_params.add(param_name) - # Found a previous shard of the param, concatenate shards ordered by ranks - self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v)) - else: - self.aggregate_state_dict[param_name] = v - else: - if k in self.aggregate_state_dict: - assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value" - else: - self.aggregate_state_dict[k] = v - self._update_weight_statistics(weight_name, v) - - def aggregate_checkpoints(self): - checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0] - self.aggregate_state_dict = dict() - - for i in range(self.world_size): - checkpoint_name = get_checkpoint_name(checkpoint_prefix, True, i, self.world_size) - rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu")) - if "model" in rank_state_dict: - rank_state_dict = rank_state_dict["model"] - - if self.clean_state_dict: - rank_state_dict = self.clean_state_dict(rank_state_dict) - - rank_state_dict = _split_state_dict(rank_state_dict) - self._aggregate(rank_state_dict["fp16_param"]) - self._aggregate(rank_state_dict["fp32_param"]) - self._aggregate(rank_state_dict["optimizer"]) - - for k in self.sharded_params: - self._reshape_tensor(k) - return self.aggregate_state_dict diff --git a/orttraining/orttraining/python/deprecated/__init__.py b/orttraining/orttraining/python/deprecated/__init__.py deleted file mode 100644 index 6e02db707bc47..0000000000000 --- a/orttraining/orttraining/python/deprecated/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -from onnxruntime.capi._pybind_state import TrainingParameters # noqa: F401 -from onnxruntime.capi.training.training_session import TrainingSession # noqa: F401 diff --git a/orttraining/orttraining/python/deprecated/training_session.py b/orttraining/orttraining/python/deprecated/training_session.py deleted file mode 100644 index a6900578e174b..0000000000000 --- a/orttraining/orttraining/python/deprecated/training_session.py +++ /dev/null @@ -1,68 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import os # noqa: F401 -import sys # noqa: F401 - -from onnxruntime.capi import _pybind_state as C -from onnxruntime.capi.onnxruntime_inference_collection import IOBinding # noqa: F401 -from onnxruntime.capi.onnxruntime_inference_collection import ( - InferenceSession, - Session, - check_and_normalize_provider_args, -) - - -class TrainingSession(InferenceSession): - def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None): - Session.__init__(self) - - if sess_options: - self._sess = C.TrainingSession(sess_options) - else: - self._sess = C.TrainingSession() - - # providers needs to be passed explicitly as of ORT 1.10 - # retain the pre-1.10 behavior by setting to the available providers. - if providers is None: - providers = C.get_available_providers() - - providers, provider_options = check_and_normalize_provider_args( - providers, provider_options, C.get_available_providers() - ) - - if isinstance(path_or_bytes, str): - config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options) - elif isinstance(path_or_bytes, bytes): - config_result = self._sess.read_bytes(path_or_bytes, parameters, providers, provider_options) - else: - raise TypeError(f"Unable to load from type '{type(path_or_bytes)}'") - - self.loss_scale_input_name = config_result.loss_scale_input_name - - self._inputs_meta = self._sess.inputs_meta - self._outputs_meta = self._sess.outputs_meta - - def __del__(self): - if self._sess: - self._sess.finalize() - - def get_state(self): - return self._sess.get_state() - - def get_model_state(self, include_mixed_precision_weights=False): - return self._sess.get_model_state(include_mixed_precision_weights) - - def get_optimizer_state(self): - return self._sess.get_optimizer_state() - - def get_partition_info_map(self): - return self._sess.get_partition_info_map() - - def load_state(self, dict, strict=False): - self._sess.load_state(dict, strict) - - def is_output_fp32_node(self, output_name): - return self._sess.is_output_fp32_node(output_name) diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py deleted file mode 100644 index 5286c087cfb64..0000000000000 --- a/orttraining/orttraining/python/ort_trainer.py +++ /dev/null @@ -1,1241 +0,0 @@ -import io -import os -import warnings - -import numpy as np -import onnx -import torch -import torch.nn -import torch.onnx -from onnx import helper, numpy_helper -from packaging.version import Version as LooseVersion - -import onnxruntime as ort -import onnxruntime.capi.pt_patch -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - -from ..training import postprocess -from .checkpointing_utils import CombineZeroCheckpoint, get_checkpoint_name, list_checkpoint_files - -DEFAULT_OPSET_VERSION = 14 - - -class IODescription: - def __init__(self, name, shape, dtype=None, num_classes=None): - self.name_ = name - self.shape_ = shape - self.dtype_ = dtype - self.num_classes_ = num_classes - - -class ModelDescription: - def __init__(self, inputs, outputs): - self.inputs_ = inputs - self.outputs_ = outputs - - -def resolve_symbolic_dimensions(inputs, input_descs, output_descs): - import copy - - output_descs_copy = copy.deepcopy(output_descs) - resolved_dims = {} - for input, input_desc in zip(inputs, input_descs): - for i, axis in enumerate(input_desc.shape_): - if isinstance(axis, str): - resolved_dims[axis] = input.size()[i] - - for output_desc in output_descs_copy: - for i, axis in enumerate(output_desc.shape_): - if isinstance(axis, str): - output_desc.shape_[i] = resolved_dims[axis] - - if any(isinstance(axis, str) for axis in output_desc.shape_ for output_desc in output_descs): - raise RuntimeError("Cannot run model with unknown output dimensions") - - return output_descs_copy - - -def generate_sample(desc, device=None): - # symbolic dimensions are described with strings. set symbolic dimensions to be 1 - size = [s if isinstance(s, (int)) else 1 for s in desc.shape_] - if desc.num_classes_: - return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device) - else: - return torch.randn(size, dtype=desc.dtype_).to(device) - - -def get_device_index(device): - if type(device) == str: # noqa: E721 - # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 - device = torch.device(device) - return 0 if device.index is None else device.index - - -def input_get_device_index(input): - if isinstance(input, (list, tuple)): - device_index = get_device_index(input[0].device) - else: - device_index = get_device_index(input.device) - - return device_index - - -def get_all_gradients_finite_arg_name(session): - all_fp16_or_fp32_gradients_finite_node_args = [x for x in session._outputs_meta if "all_gradients_finite" in x.name] - if len(all_fp16_or_fp32_gradients_finite_node_args) < 1: - raise RuntimeError( - "Failed to find a group NodeArg with name that matches 'all_gradients_finite'\ - from the training session." - ) - - return all_fp16_or_fp32_gradients_finite_node_args[0].name - - -def get_group_accumulated_gradients_output_node_arg_name(session): - # TODO: get the constant string via pybind. - # optimizer_graph_builder BuildGroupNode with fixed string: 'Group_Accumulated_Gradients' - accumulated_gradients_output_node_args = [ - x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name - ] - if len(accumulated_gradients_output_node_args) != 1: - raise RuntimeError( - "Failed to find a group NodeArg with name that matches 'Group_Accumulated_Gradients'\ - from the training session." - ) - - return accumulated_gradients_output_node_args[0].name - - -def ort_training_session_run_helper(session, iobinding, inputs, input_descs, output_descs, device, run_options=None): - for input, input_desc in zip(inputs, input_descs): - device_index = input_get_device_index(input) - iobinding.bind_input( - input_desc.name_, - input.device.type, - device_index, - dtype_torch_to_numpy(input.dtype), - list(input.size()), - input.data_ptr(), - ) - - output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs) - torch_outputs = {} - for output_desc in output_descs_resolved: - torch_tensor = torch.zeros( - output_desc.shape_, - device=device, - dtype=output_desc.eval_dtype_ if hasattr(output_desc, "eval_dtype_") else output_desc.dtype_, - ) - iobinding.bind_output( - output_desc.name_, - torch_tensor.device.type, - get_device_index(device), - dtype_torch_to_numpy(torch_tensor.dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - torch_outputs[output_desc.name_] = torch_tensor - - session.run_with_iobinding(iobinding, run_options) - return torch_outputs - - -def FuseSofmaxNLLToSoftmaxCE(onnx_model): # noqa: N802 - nll_count = 0 - while True: - nll_count = nll_count + 1 - nll_loss_node = None - nll_loss_node_index = 0 - for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss": - nll_loss_node = node - break - - if nll_loss_node is None: - break - - softmax_node = None - softmax_node_index = 0 - label_input_name = None - weight_input_name = None - for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "LogSoftmax": - # has to be connected to nll_loss - if len(nll_loss_node.input) > 2: - weight_input_name = nll_loss_node.input[2] - if node.output[0] == nll_loss_node.input[0]: - softmax_node = node - label_input_name = nll_loss_node.input[1] - break - elif node.output[0] == nll_loss_node.input[1]: - softmax_node = node - label_input_name = nll_loss_node.input[0] - break - else: - if softmax_node is not None: - break - - if softmax_node is None: - break - - # delete nll_loss and LogSoftmax nodes in order - if nll_loss_node_index < softmax_node_index: - del onnx_model.graph.node[softmax_node_index] - del onnx_model.graph.node[nll_loss_node_index] - else: - del onnx_model.graph.node[nll_loss_node_index] - del onnx_model.graph.node[softmax_node_index] - - probability_output_name = softmax_node.output[0] - node = onnx_model.graph.node.add() - inputs = ( - [softmax_node.input[0], label_input_name, weight_input_name] - if weight_input_name - else [softmax_node.input[0], label_input_name] - ) - node.CopyFrom( - onnx.helper.make_node( - "SparseSoftmaxCrossEntropy", - inputs, - [nll_loss_node.output[0], probability_output_name], - "nll_loss_node_" + str(nll_count), - ) - ) - - return onnx_model - - -def delete_input_with_name(input, name): - index = 0 - for i in input: - if i.name == name: - del input[index] - break - index = index + 1 - - -# reference: -# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html -# https://pytorch.org/docs/stable/tensors.html -# also must map to types accepted by: -# MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type) -def dtype_torch_to_numpy(torch_dtype): - if torch_dtype == torch.float64 or torch_dtype == torch.double: - return np.float64 - elif torch_dtype == torch.float32 or torch_dtype == torch.float: - return np.float32 - elif torch_dtype == torch.float16 or torch_dtype == torch.half: - return np.float16 - elif torch_dtype == torch.int64 or torch_dtype == torch.long: - return np.longlong - elif torch_dtype == torch.int32 or torch_dtype == torch.int: - return np.int32 - elif torch_dtype == torch.int16 or torch_dtype == torch.short: - return np.int16 - elif torch_dtype == torch.bool: - return bool - else: - raise Exception("Torch type to numpy type mapping unavailable for: " + str(torch_dtype)) - - -class model_loss_cls(torch.nn.Module): # noqa: N801 - def __init__(self, model, loss_fn): - super().__init__() - self.model_ = model - self.loss_fn_ = loss_fn - - def forward(self, *inputs): - # here we assume input can be unpacked into input and label - input, label = inputs[:-1], inputs[-1] - preds = self.model_(*input) - return self.loss_fn_(preds, label), preds - - -class WrapModel(torch.nn.Module): - def __init__(self, model, loss_fn, input_names): - super().__init__() - self.model_ = model - self.loss_fn_ = loss_fn - self.input_names_ = input_names - - def forward(self, *inputs): - import inspect - - # *inputs is given by torch trace. It is in the order of input_names. - # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names. - sig = inspect.signature(self.model_.forward) - list(sig.parameters.keys()) - - input_dict = {} - for key in sig.parameters: - if key in self.input_names_: - input_dict[key] = inputs[self.input_names_.index(key)] - - model_out = self.model_(**input_dict) - if self.loss_fn_ is None: - return model_out - - label = inputs[-1] - preds = model_out - return self.loss_fn_(preds, label), preds - - -def wrap_for_input_match(model, loss_fn, input_names): - import inspect - - sig = inspect.signature(model.forward) - ordered_list_keys = list(sig.parameters.keys()) - if loss_fn: - sig_loss = inspect.signature(loss_fn) - if len(sig_loss.parameters) != 2: - raise RuntimeError("loss function should take two arguments - predict and label.") - - # label shall be the second input to loss_fn. - ordered_list_keys = [*ordered_list_keys, list(sig_loss.parameters.keys())[1]] - - # name match is needed only when input_names are a subset - # of expected inputs (inputs to model and loss_fn combined). - if len(input_names) > len(ordered_list_keys): - # this is likely the case where input arguments are packed. - # TODO: to unpack the input argument. - return model_loss_cls(model, loss_fn) if loss_fn else model - elif len(input_names) == len(ordered_list_keys): - # in this case, we do not require name match. - return model_loss_cls(model, loss_fn) if loss_fn else model - - if not all(x in ordered_list_keys for x in input_names): - # model desc has name(s) not matching the model signature. We cannot do anything in this case. - # better to warning the user. - return model_loss_cls(model, loss_fn) if loss_fn else model - - # if input_names match ordered_list_keys, there is not need for wrapping - match = True - for i, input_name in enumerate(input_names): - if input_name != ordered_list_keys[i]: - match = False - break - - if match: - return model_loss_cls(model, loss_fn) if loss_fn else model - - model = WrapModel(model, loss_fn, input_names) - - return model - - -def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, opset_version=DEFAULT_OPSET_VERSION): - # example: {input0:{0:'batch'}, input1:{0:'batch'}} - dynamic_axes = {} - for input in model_desc.inputs_: - symbolic_axis = {} - for i, axis in enumerate(input.shape_): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[input.name_] = symbolic_axis - - for output in model_desc.outputs_: - symbolic_axis = {} - for i, axis in enumerate(output.shape_): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[output.name_] = symbolic_axis - - input_names = [input.name_ for input in model_desc.inputs_] - output_names = [output.name_ for output in model_desc.outputs_] - - if isinstance(inputs, torch.Tensor): - inputs = [inputs] - if isinstance(inputs, dict): - sample_inputs = [inputs[k.name_].to(device=device) for k in model_desc.inputs_] - elif isinstance(inputs, (list, tuple)): - sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(model_desc.inputs_)] - else: - raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.") - - # pytorch onnx exporter/trace does not try to match argument names. - # e.g. for models with optional inputs, it requires all inputs be present. - # this is a problem because the model graph depends on inputs provided. - model = wrap_for_input_match(model, loss_fn, input_names) - - model.eval() - with torch.no_grad(): - import copy - - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - try: - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(model) - except Exception: - model_copy = model - warnings.warn( - "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." - " Compute will continue, but unexpected results may occur!" - ) - - sample_outputs = model_copy(*sample_inputs_copy) - if isinstance(sample_outputs, torch.Tensor): - sample_outputs = [sample_outputs] - for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_): - output_desc.dtype_ = sample_output.dtype - model.train() - - f = io.BytesIO() - - # Other export options to use(this is for backward compatibility). - other_export_options = {} - other_export_options["training"] = True - - # This option was added after 1.4 release. - if LooseVersion(torch.__version__) > LooseVersion("1.4.0") and LooseVersion(torch.__version__) < LooseVersion( - "1.10.0" - ): - other_export_options["enable_onnx_checker"] = False - # This option was added after 1.6 release. - if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - other_export_options["training"] = torch.onnx.TrainingMode.TRAINING - - # Deepcopy inputs, since input values may change after model run. - import copy - - sample_inputs_copy = copy.deepcopy(sample_inputs) - - # Enable contrib ops export from PyTorch - from onnxruntime.tools import pytorch_export_contrib_ops - - pytorch_export_contrib_ops.register() - - torch.onnx._export( - model, - tuple(sample_inputs_copy), - f, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False, - **other_export_options, - ) - - onnx_model = onnx.load_model_from_string(f.getvalue()) - - # Remove 'model_.' prefix introduced by model wrapper for initializers. - if isinstance(model, (WrapModel, model_loss_cls)): - replace_name_dict = {} - for n in onnx_model.graph.initializer: - if n.name.startswith("model_."): - replace_name_dict[n.name] = n.name[len("model_.") :] - n.name = replace_name_dict[n.name] - for n in onnx_model.graph.node: - for i, name in enumerate(n.input): - if name in replace_name_dict: - n.input[i] = replace_name_dict[name] - - return onnx_model - - -def create_ort_training_session_with_optimizer( - model, - device, - training_optimizer_name, - lr_params_feed_name, - map_optimizer_attributes, - world_rank=-1, - world_size=1, - gradient_accumulation_steps=1, - bind_parameters=False, - use_mixed_precision=False, - allreduce_post_accumulation=False, - deepspeed_zero_stage=0, - enable_grad_norm_clip=True, - frozen_weights=[], # noqa: B006 - opset_version=DEFAULT_OPSET_VERSION, - use_deterministic_compute=False, - use_memory_efficient_gradient=False, - enable_adasum=False, - optimized_model_filepath="", -): - output_name = model.graph.output[0].name - ort_parameters = ort.TrainingParameters() - ort_parameters.loss_output_name = output_name - ort_parameters.use_mixed_precision = use_mixed_precision - ort_parameters.world_rank = world_rank - ort_parameters.world_size = world_size - ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps - ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation - ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage - ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip - ort_parameters.set_gradients_as_graph_outputs = False - ort_parameters.use_memory_efficient_gradient = use_memory_efficient_gradient - ort_parameters.enable_adasum = enable_adasum - output_types = {} - for output in model.graph.output: - output_types[output.name] = output.type.tensor_type - - # pybind does not allow to add directly to ort_parameters.weights_to_train. - # Have to work around by using a temporary weights_to_train. - torch_params = {} - optimizer_attributes_map = {} - optimizer_int_attributes_map = {} - - unused_frozen_weights = [n for n in frozen_weights if n not in [i.name for i in model.graph.initializer]] - if unused_frozen_weights: - raise RuntimeError(f"{unused_frozen_weights} in frozen_weights not found in model weights.") - - weights_to_train = set() - for initializer in model.graph.initializer: - if initializer.name in frozen_weights: - continue - weights_to_train.add(initializer.name) - if map_optimizer_attributes is not None: - attributes = map_optimizer_attributes(initializer.name) - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - for k, v in attributes.items(): - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - else: - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - - if bind_parameters: - for initializer in model.graph.initializer: - torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device)) - delete_input_with_name(model.graph.input, initializer.name) - model.graph.input.extend( - [helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)] - ) - torch_params[initializer.name] = torch_tensor - - del model.graph.initializer[:] - - ort_parameters.weights_to_train = weights_to_train - ort_parameters.training_optimizer_name = training_optimizer_name - ort_parameters.lr_params_feed_name = lr_params_feed_name - ort_parameters.optimizer_attributes_map = optimizer_attributes_map - ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map - - sessionOptions = ort.SessionOptions() # noqa: N806 - sessionOptions.use_deterministic_compute = use_deterministic_compute - if len(optimized_model_filepath) > 0: - sessionOptions.optimized_model_filepath = optimized_model_filepath - session = ort.TrainingSession(model.SerializeToString(), ort_parameters, sessionOptions) - train_io_binding = session.io_binding() - eval_io_binding = session.io_binding() - - if bind_parameters: - for param in torch_params: - torch_tensor = torch_params[param] - - train_io_binding.bind_input( - param, - torch_tensor.device.type, - get_device_index(torch_tensor.device), - dtype_torch_to_numpy(torch_params[param].dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - eval_io_binding.bind_input( - param, - torch_tensor.device.type, - get_device_index(torch_tensor.device), - dtype_torch_to_numpy(torch_params[param].dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - - return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types - - -def save_checkpoint( - model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", checkpoint_state_dict=None, include_optimizer_state=True -): - if checkpoint_state_dict is None: - checkpoint_state_dict = {"model": model.state_dict(include_optimizer_state)} - else: - checkpoint_state_dict.update({"model": model.state_dict(include_optimizer_state)}) - - assert os.path.exists(checkpoint_dir), f"ERROR: Checkpoint directory doesn't exist: {checkpoint_dir}" - - checkpoint_name = get_checkpoint_name( - checkpoint_prefix, model.deepspeed_zero_stage_, model.world_rank, model.world_size - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if os.path.exists(checkpoint_file): - warnings.warn(f"{checkpoint_file} already exists, overwriting.") - - torch.save(checkpoint_state_dict, checkpoint_file) - - -def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict): - checkpoint_name = get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if is_partitioned: - assert_msg = ( - f"Couldn't find checkpoint file {checkpoint_file}." - "Optimizer partitioning is enabled using ZeRO. Please make sure that the " - f"checkpoint file exists for rank {model.world_rank} of {model.world_size}." - ) - else: - assert_msg = f"Couldn't find checkpoint file {checkpoint_file}." - - assert os.path.exists(checkpoint_file), assert_msg - - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - - model.load_state_dict(checkpoint_state["model"], strict=strict) - del checkpoint_state["model"] - return checkpoint_state - - -def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict): - checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - - ckpt_agg = CombineZeroCheckpoint(checkpoint_files) - aggregate_state_dict = ckpt_agg.aggregate_checkpoints() - - model.load_state_dict(aggregate_state_dict, strict=strict) - - # aggregate other keys in the state_dict. - # Values will be overwritten for matching keys among workers - all_checkpoint_states = {} - for checkpoint_file in checkpoint_files: - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - del checkpoint_state["model"] - all_checkpoint_states.update(checkpoint_state) - return all_checkpoint_states - - -def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): - checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - is_partitioned = False - if len(checkpoint_files) > 1: - warnings.warn( - f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." - "Attempting to load ZeRO checkpoint." - ) - is_partitioned = True - if (not model.deepspeed_zero_stage_) and is_partitioned: - return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict) - else: - return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) - - -class ORTTrainer: - def __init__( - self, - model, - loss_fn, - model_desc, - training_optimizer_name, - map_optimizer_attributes, - learning_rate_description, - device, - gradient_accumulation_steps=1, - world_rank=0, - world_size=1, - use_mixed_precision=False, - allreduce_post_accumulation=False, - global_step=0, - get_lr_this_step=None, - loss_scaler=None, - deepspeed_zero_stage=0, - enable_grad_norm_clip=True, - frozen_weights=[], # noqa: B006 - _opset_version=DEFAULT_OPSET_VERSION, - _enable_internal_postprocess=True, - _extra_postprocess=None, - _use_deterministic_compute=False, - use_memory_efficient_gradient=False, - run_symbolic_shape_infer=False, - enable_adasum=False, - optimized_model_filepath="", - ): - super().__init__() - """ - Initialize ORTTrainer. - - Args: - - model: one of - - a PyTorch model (class that inherits from torch.nn.Module) - - a combined PyTorch model and loss function. - Inputs to this combined PyTorch model are a concatenation of the - model's input and the loss function's label input. - Outputs are a concatenation of the loss function's output and the - model's output. - - a combined ONNX model and loss function. - loss_fn: one of - - a PyTorch loss function if 'model' is a PyTorch model. A loss - function takes two inputs (prediction, label) and outputs a loss - tensor. - - None if model is already combined with a loss function. - model_desc: Specify input/output shapes, types, and names. - Must be consistent with the training model. - training_optimizer_name: one of - - 'SGDOptimizer' - - 'AdamOptimizer' - - 'LambOptimizer' - map_optimizer_attributes: for optimizers with weight-dependent - parameters. A callable that maps weight name to a set of optimization - parameters. - Defaults to None. - learning_rate_description: the name, shape and type of the learning - rate in form of IODescription(Learning_Rate_Name, [1,], torch.float32). - Because learning_rate is an input to the training model, - Learning_Rate_Name must be specified so that there is no name conflict - within the model. - device: device to store tensors (e.g. 'cpu', 'cuda', 'cuda:'). - gradient_accumulation_steps: number of training steps to accumulate - gradients before averaging and applying them. - Defaults to 1. - world_rank: rank id used for distributed training. - Defaults to 0. - world_size: number of ranks participating in distributed training. - Defaults to 1. - use_mixed_precision: flag to enable mixed precision (aka fp16). - Defaults to False. - allreduce_post_accumulation: controls whether overlaping gradient - computation is applied with allreduce. - Defaults to False. - global_step: training step that is used as input to 'get_lr_this_step'. - Defaults to 0. - get_lr_this_step: functor used as learning rate scheduler. - It uses 'global_step' as input. - Defaults to None. - loss_scaler: updates loss scale automatically when 'use_mixed_precision' - is specified. - Defaults to None. - deepspeed_zero_stage: controls whether to partition state using the DeepSpeed ZeRO technique. Stages 0 and 1 are supported. - Defaults to 0 (disabled). - enable_grad_norm_clip: enables gradient norm clipping. - Defaults to True. - frozen_weights: list of model parameters to be frozen (not trained). - Defaults to []. - _enable_internal_postprocess: whether to run or not the internal postprocesses. - Defaults to True - _extra_postprocess: a callable to postprocess the ONNX model that is converted from PyTorch. - Defaults to None - use_memory_efficient_gradient: use memory aware gradient builder. - Defaults to False - run_symbolic_shape_infer: run symbolic shape inference - Defaults to False - optimized_model_filepath: path to output the optimized training graph. - Defaults to "" (no output). - """ - warnings.warn( - "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", - FutureWarning, - ) - warnings.warn( - "DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it" - ) - self.is_train = True - - self.torch_model_ = None - self.onnx_model_ = None - self._enable_internal_postprocess = _enable_internal_postprocess - self._extra_postprocess = _extra_postprocess - - if isinstance(model, torch.nn.Module): - self.torch_model_ = model - self.loss_fn_ = loss_fn - self._torch_state_dict_keys = list(model.state_dict().keys()) - else: - self._torch_state_dict_keys = [] - self.onnx_model_ = model - if loss_fn is not None: - warnings.warn("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.") - # TODO: accept loss_fn as an onnx model. build self.onnx_model_ with model and loss_fn - self.loss_fn_ = None - - if self._enable_internal_postprocess: - postprocess.run_postprocess(self.onnx_model_) - - if self._extra_postprocess: - self._extra_postprocess(self.onnx_model_) - - self.model_desc_ = model_desc - self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description] - - self.world_rank = world_rank - self.world_size = world_size - self.use_mixed_precision = use_mixed_precision - - self.session = None - self.device_ = device - self.gradient_accumulation_steps = gradient_accumulation_steps - # we use self.current_step to count calls to train_step. It is used for gradient accumulation. - # gradients are being accumulated when self.current_step is not divisible by gradient_accumulation_steps. - # gradients are updated when self.current_step is divisible by gradient_accumulation_steps. - self.current_step = 0 - - # we use self.global_step_ to count optimizations being performed. - # it is used to calculate learning rate if self.get_lr_this_step_ is provided. - self.global_step_ = global_step - self.get_lr_this_step_ = get_lr_this_step - self.loss_scaler_ = loss_scaler - - if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None: - warnings.warn("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.") - self.training_optimizer_name_ = training_optimizer_name - self.learning_rate_description_ = learning_rate_description - self.map_optimizer_attributes_ = map_optimizer_attributes - self.allreduce_post_accumulation_ = allreduce_post_accumulation - self.deepspeed_zero_stage_ = deepspeed_zero_stage - self.enable_grad_norm_clip_ = enable_grad_norm_clip - self.frozen_weights_ = frozen_weights - self.opset_version_ = _opset_version - self.state_dict_ = None - self._use_deterministic_compute = _use_deterministic_compute - self.use_memory_efficient_gradient = use_memory_efficient_gradient - self.run_symbolic_shape_infer = run_symbolic_shape_infer - self.enable_adasum = enable_adasum - self.optimized_model_filepath = optimized_model_filepath - - # use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs. - # see prepare_input_and_fetches for more details. - self.loss_scale_input_name = "default_loss_scale_input_name" - - self._init_session() - - def _init_session(self): - if self.onnx_model_ is None: - return - - self._verify_fully_optimized_model(self.onnx_model_) - - if self.run_symbolic_shape_infer: - self.onnx_model_ = SymbolicShapeInference.infer_shapes( - self.onnx_model_, auto_merge=True, guess_output_rank=True - ) - - # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. - # for example, load_state_dict will be called before returing the function, and it calls _init_session again - del self.session - ( - self.session, - self.train_io_binding, - self.eval_io_binding, - self.output_name, - _, - self.output_types, - ) = create_ort_training_session_with_optimizer( - self.onnx_model_, - self.device_, - self.training_optimizer_name_, - self.learning_rate_description_.name_, - self.map_optimizer_attributes_, - self.world_rank, - self.world_size, - self.gradient_accumulation_steps, - bind_parameters=False, - use_mixed_precision=self.use_mixed_precision, - allreduce_post_accumulation=self.allreduce_post_accumulation_, - deepspeed_zero_stage=self.deepspeed_zero_stage_, - enable_grad_norm_clip=self.enable_grad_norm_clip_, - frozen_weights=self.frozen_weights_, - opset_version=self.opset_version_, - use_deterministic_compute=self._use_deterministic_compute, - use_memory_efficient_gradient=self.use_memory_efficient_gradient, - enable_adasum=self.enable_adasum, - optimized_model_filepath=self.optimized_model_filepath, - ) - - self.loss_scale_input_name = self.session.loss_scale_input_name - - if self.use_mixed_precision: - self.input_desc_with_lr_and_loss_scale = [ - *self.input_desc_with_lr, - IODescription(self.loss_scale_input_name, [], torch.float32), - ] - - # ORT backend has modified model output dtype from float32 to float16. - for o_desc in self.model_desc_.outputs_: - if ( - self.use_mixed_precision - and o_desc.dtype_ == torch.float32 - and not self.session.is_output_fp32_node(o_desc.name_) - ): - o_desc.eval_dtype_ = torch.float16 - else: - o_desc.eval_dtype_ = o_desc.dtype_ - - # gradient accumulation buffers are connected to a single node with a boolean, dimension 1 tensor output. - # add a matching output to drive gradient accumulation. - if self.gradient_accumulation_steps > 1: - self.output_desc_with_group_accumulated_gradients = [ - *self.model_desc_.outputs_, - IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool), - ] - - if self.use_mixed_precision: - # when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine - # if the gradient is usable. - self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [ - *self.model_desc_.outputs_, - IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool), - ] - - if self.state_dict_: - self.load_state_dict(self.state_dict_, self.strict_) - self.state_dict_ = None - - def _init_onnx_model(self, inputs): - if self.onnx_model_ is not None: - return - - if self.torch_model_ is not None: - # NOTE: pt model is moved to cpu to conserve gpu memory. - self.torch_model_.cpu() - # torch buffers created using 'register_buffer' are not meant to be trainable. - torch_buffers = list(dict(self.torch_model_.named_buffers()).keys()) - self.frozen_weights_ = self.frozen_weights_ + torch_buffers - self.onnx_model_ = convert_model_loss_fn_to_onnx( - self.torch_model_, - self.loss_fn_, - self.model_desc_, - torch.device("cpu"), - inputs, - opset_version=self.opset_version_, - ) - - if self._enable_internal_postprocess: - postprocess.run_postprocess(self.onnx_model_) - - if self._extra_postprocess: - self._extra_postprocess(self.onnx_model_) - - self._init_session() - - def train(self): - self.is_train = True - - def eval(self): - self.is_train = False - - def _update_onnx_model_initializers(self, state_tensors): - # replace the initializers with new value - new_weights = [] - replace_indices = [] - for i, w in enumerate(self.onnx_model_.graph.initializer): - if w.name in state_tensors: - new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name)) - replace_indices.append(i) - replace_indices.sort(reverse=True) - for w_i in replace_indices: - del self.onnx_model_.graph.initializer[w_i] - self.onnx_model_.graph.initializer.extend(new_weights) - - def state_dict(self, include_optimizer_state=True): - if not self.session: - warnings.warn( - "ONNXRuntime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling state_dict()." - ) - return {} - - # extract trained weights - session_state = self.session.get_state() - torch_state = {} - for name in session_state: - torch_state[name] = torch.from_numpy(session_state[name]) - - # extract untrained weights and buffer - for n in self.onnx_model_.graph.initializer: - if n.name not in torch_state: - torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n)) - - # Need to remove redundant initializers and name suffices to map back to original torch state names - if not include_optimizer_state and self._torch_state_dict_keys: - return {key: torch_state[key] for key in self._torch_state_dict_keys if key in torch_state} - return torch_state - - def load_state_dict(self, state_dict, strict=False): - # Note: It may happen ONNX model has not yet been initialized - # In this case we cache a reference to desired state and delay the restore until after initialization - # Unexpected behavior will result if the user changes the reference before initialization - if not self.session: - self.state_dict_ = state_dict - self.strict_ = strict - return - - # update onnx model from loaded state dict - cur_initializers_names = [n.name for n in self.onnx_model_.graph.initializer] - new_initializers = {} - - for name in state_dict: - if name in cur_initializers_names: - new_initializers[name] = state_dict[name].numpy() - elif strict: - raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.") - - self._update_onnx_model_initializers(new_initializers) - - # create new session based on updated onnx model - self.state_dict_ = None - self._init_session() - - # load training state - session_state = {name: state_dict[name].numpy() for name in state_dict} - self.session.load_state(session_state, strict) - - def save_as_onnx(self, path): - if not self.session: - warnings.warn( - "ONNXRuntime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling save_as_onnx()." - ) - return - state_tensors = self.session.get_state() - self._update_onnx_model_initializers(state_tensors) - - with open(path, "wb") as f: - f.write(self.onnx_model_.SerializeToString()) - - def _prepare_input_and_fetches( - self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs - ): - fetches = None - if type(args) == tuple and len(args) == 1 and type(args[0]) == list: # noqa: E721 - input = tuple(args[0]) - else: - input = args - - for input_desc in input_desc_with_: - if input_desc.name_ in kwargs: - input = (*input, kwargs[input_desc.name_]) - if internal_learning_rate is not None: - input = (*input, internal_learning_rate) - if internal_loss_scale is not None: - input = (*input, internal_loss_scale) - elif self.use_mixed_precision: - # loss_scale input name is needed to call train_step, for example: - # kwargs[model.loss_scale_input_name] = loss_scale - # outputs = model.train_step(*args, **kwargs) - # However, when first time train_step is called model.loss_scale_input_name is not set. - # To workaround this problem, we use the special name 'default_loss_scale_input_name' to indicate - # the loss_scale. - if "default_loss_scale_input_name" in kwargs: - input = (*input, kwargs["default_loss_scale_input_name"]) - - fetches = None - if "fetches" in kwargs: - fetches = kwargs["fetches"] - - return input, fetches - - def train_step(self, *args, **kwargs): - """ - inputs: model inputs, labels, learning rate, and, if in mixed_precision mode, loss_scale. - outputs: if fetches is not provided, outputs are loss and - (if in mixed mode and is finishing gradient accumulation) all_finite. - if fetches is provided, outputs contains these requested with fetches. - fetches: names of requested outputs - """ - - # inputs to the ONNX model includes inputs to the original PyTorch model - # plus learning rate and loss_scale if self.use_mixed_precision is True. - # 1. when there are internal learning_rate and loss_scale (in fp16 cases) generators, - # *args and **kwargs together contain ONLY and COMPLETE inputs to the PyTorch model. - # In this case, changes to the training script is minimized. - # 2. without internal learning rate and loss scale (in fp16 cases) generators, - # *args and **kwargs passed in from the training script shall contains - # inputs to the PyTorch model plus learning_rate and loss_scale. - # it optionally contains the fetches. - # localized arguments (*args) contains inputs to the ONNX model. - # named arguments can contain both inputs, learning_rate and loss_scale, and the fetches - - learning_rate, loss_scale = None, None - if self.get_lr_this_step_ is not None: - # $args, **kwargs contains inputs to the pytorch model - lr_this_step = self.get_lr_this_step_(self.global_step_) - learning_rate = torch.tensor([lr_this_step]) - if self.loss_scaler_ is not None and self.use_mixed_precision: - loss_scale = torch.tensor([self.loss_scaler_.loss_scale_]) - - if self.onnx_model_ is None: - sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) - self._init_onnx_model(sample_input) - - if self.use_mixed_precision: - input, fetches = self._prepare_input_and_fetches( - self.input_desc_with_lr_and_loss_scale, learning_rate, loss_scale, *args, **kwargs - ) - assert len(self.input_desc_with_lr_and_loss_scale) == len(input) - input_descs = self.input_desc_with_lr_and_loss_scale - else: - input, fetches = self._prepare_input_and_fetches( - self.input_desc_with_lr, learning_rate, loss_scale, *args, **kwargs - ) - assert len(self.input_desc_with_lr) == len(input) - input_descs = self.input_desc_with_lr - - self.current_step += 1 - - # handle gradient accumulation in fully optimized mode - run_options = None - has_if_all_finite = False - if fetches: - output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch] - elif self.current_step % self.gradient_accumulation_steps != 0: - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - output_desc = self.output_desc_with_group_accumulated_gradients - elif self.use_mixed_precision: - has_if_all_finite = True - output_desc = self.output_desc_with_all_fp_16_or_fp32_gradients_finite - else: - output_desc = self.model_desc_.outputs_ - - if not isinstance(input, (list, tuple)): - input = (input,) - - session_run_results = ort_training_session_run_helper( - self.session, self.train_io_binding, input, input_descs, output_desc, self.device_, run_options - ) - - if has_if_all_finite: - # After session run with all_fp32_gradients_finite, we need to clear the iobinding's output state. - # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce - # because all_fp32_gradients_finite is still in the feed. - self.train_io_binding.clear_binding_outputs() - all_finite = session_run_results[self.output_desc_with_all_fp_16_or_fp32_gradients_finite[-1].name_] - if self.loss_scaler_ is not None: - self.loss_scaler_.update_loss_scale(all_finite) - if all_finite: - # optimization has done, increase self.global_step_ - self.global_step_ = self.global_step_ + 1 - elif self.current_step % self.gradient_accumulation_steps == 0: - # optimization has done, increase self.global_step_ - self.global_step_ = self.global_step_ + 1 - - if fetches is not None: - results = [session_run_results[fetch] for fetch in fetches] - elif has_if_all_finite and self.loss_scaler_ is None: - # return descripted outputs plus the all_finite flag so that the training script can handle loss scaling. - results = [ - session_run_results[output_desc.name_] - for output_desc in self.output_desc_with_all_fp_16_or_fp32_gradients_finite - ] - else: - results = [session_run_results[output_desc.name_] for output_desc in self.model_desc_.outputs_] - return results[0] if len(results) == 1 else results - - def __call__(self, *args, **kwargs): - if self.is_train: - return self.train_step(*args, **kwargs) - else: - return self.eval_step(*args, **kwargs) - - def eval_step(self, *args, **kwargs): - """ - inputs: model inputs and/or labels. - outputs: if 'fetches' is not provided, outputs are loss and - (if in mixed mode and is finishing gradient accumulation) all_finite. - if fetches is provided, outputs contains these requested with fetches. - fetches: names of requested outputs - """ - - # with model_loss_cls, the last input is label, first output is loss - input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) - - if self.onnx_model_ is None: - if self.torch_model_ is not None: - self._init_onnx_model(input) - else: - raise RuntimeError( - "Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer." - ) - - input_desc = self.model_desc_.inputs_[0 : len(input)] - if fetches is None: - output_desc = self.model_desc_.outputs_ - else: - output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch] - - if not isinstance(input, (list, tuple)): - input = (input,) - - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - run_options.training_mode = False - - session_run_results = ort_training_session_run_helper( - self.session, self.eval_io_binding, input, input_desc, output_desc, self.device_, run_options - ) - - if len(session_run_results) == 1: - return session_run_results[next(iter(session_run_results.keys()))] - else: - return [session_run_results[output_desc.name_] for output_desc in output_desc] - - def _verify_fully_optimized_model(self, model): - assert len(model.graph.output) > 0 - # model's first output must be the loss tensor - if model.graph.output[0].type.tensor_type.elem_type not in { - onnx.TensorProto.FLOAT, - onnx.TensorProto.FLOAT16, - onnx.TensorProto.DOUBLE, - onnx.TensorProto.COMPLEX64, - onnx.TensorProto.COMPLEX128, - onnx.TensorProto.BFLOAT16, - onnx.TensorProto.FLOAT8E4M3FN, - onnx.TensorProto.FLOAT8E4M3FNUZ, - onnx.TensorProto.FLOAT8E5M2, - onnx.TensorProto.FLOAT8E5M2FNUZ, - }: - raise RuntimeError( - "the first output of a model to run with fully optimized ORT backend must be float types." - ) - if len(model.graph.output[0].type.tensor_type.shape.dim) != 0: - raise RuntimeError( - "the first output of a model to run with fully optimized ORT backend assumed to be loss and must be a scalar." - ) - - -class LossScaler: - def __init__( - self, - loss_scale_input_name, - is_dynamic_scale, - loss_scale=float(1 << 16), - up_scale_window=2000, - min_loss_scale=1.0, - max_loss_scale=float(1 << 24), - ): - super().__init__() - self.loss_scale_input_name_ = loss_scale_input_name - self.is_dynamic_scale_ = is_dynamic_scale - self.initial_loss_scale_ = loss_scale - self.up_scale_window_ = up_scale_window - self.min_loss_scale_ = min_loss_scale - self.max_loss_scale_ = max_loss_scale - self.loss_scale_ = loss_scale - self.stable_steps_ = 0 - - def update_loss_scale(self, is_all_finite): - if not self.is_dynamic_scale_: - return - - if is_all_finite: - self.stable_steps_ += 1 - - if self.stable_steps_ >= self.up_scale_window_: - self.loss_scale_ = min(self.max_loss_scale_, self.loss_scale_ * 2) - self.stable_steps_ = 0 - else: - self.loss_scale_ = max(self.min_loss_scale_, self.loss_scale_ / 2) - self.stable_steps_ = 0 - - def reset(self): - self.loss_scale_ = self.initial_loss_scale_ - self.stable_steps_ = 0 diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index a08e8bee99cee..bb1cb4bbd32f7 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -18,7 +18,6 @@ #include "core/session/environment.h" #include "core/session/custom_ops.h" #include "core/dlpack/dlpack_converter.h" -#include "orttraining/core/session/training_session.h" #include "orttraining/core/agent/training_agent.h" #include "orttraining/core/graph/gradient_config.h" #include "orttraining/core/graph/optimizer_config.h" @@ -113,14 +112,11 @@ struct TrainingParameters { std::unordered_set weights_to_train; std::unordered_set weights_not_to_train; - onnxruntime::training::TrainingSession::ImmutableWeights immutable_weights; - // optimizer std::string training_optimizer_name; std::string lr_params_feed_name = "Learning_Rate"; std::unordered_map> optimizer_attributes_map; std::unordered_map> optimizer_int_attributes_map; - onnxruntime::training::TrainingSession::OptimizerState optimizer_initial_state; std::unordered_map> sliced_schema; std::unordered_map sliced_axes; std::vector sliced_tensor_names; @@ -206,185 +202,6 @@ struct PyGradientGraphBuilderContext { local_registries_(local_registries) {} }; -// TODO: this method does not handle parallel optimization. -TrainingConfigurationResult ConfigureSessionForTraining( - training::PipelineTrainingSession* sess, TrainingParameters& parameters) { - // TODO tix, refactor the mpi related code to populate all fields correctly by default. - ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size, "data_parallel_size: ", parameters.data_parallel_size, ", world_size: ", parameters.world_size); - ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size, "horizontal_parallel_size: ", parameters.horizontal_parallel_size, ", world_size: ", parameters.world_size); - ORT_ENFORCE(parameters.pipeline_parallel_size <= parameters.world_size, "pipeline_parallel_size: ", parameters.pipeline_parallel_size, ", world_size: ", parameters.world_size); - - // When DxHxP != the total number of ranks, we try adjusting D so that DxHxP == the total number of ranks. - if (parameters.world_size != parameters.data_parallel_size * parameters.horizontal_parallel_size * parameters.pipeline_parallel_size) { - ORT_ENFORCE(parameters.world_size % parameters.horizontal_parallel_size * parameters.pipeline_parallel_size == 0, - "D, H, P sizes are incorrect. To enable automatic correction, total number of ranks must be a divisible by HxP."); - - const auto new_data_parallel_size = parameters.world_size / (parameters.horizontal_parallel_size * parameters.pipeline_parallel_size); - parameters.data_parallel_size = new_data_parallel_size; - - const std::string msg = "Cannot distribute " + std::to_string(parameters.world_size) + " ranks for distributed computation with D=" + std::to_string(parameters.data_parallel_size) + - ", H=" + std::to_string(parameters.horizontal_parallel_size) + ", P=" + std::to_string(parameters.pipeline_parallel_size) + ", so D is automatically changed to " + std::to_string(new_data_parallel_size); - LOGS(*(sess->GetLogger()), WARNING) << msg; - } - - training::PipelineTrainingSession::TrainingConfiguration config{}; - config.weight_names_to_train = parameters.weights_to_train; - config.weight_names_to_not_train = parameters.weights_not_to_train; - config.immutable_weights = parameters.immutable_weights; - config.gradient_accumulation_steps = parameters.gradient_accumulation_steps; - - config.distributed_config.world_rank = parameters.world_rank; - config.distributed_config.world_size = parameters.world_size; - config.distributed_config.local_rank = parameters.local_rank; - config.distributed_config.local_size = parameters.local_size; - config.distributed_config.data_parallel_size = parameters.data_parallel_size; - config.distributed_config.horizontal_parallel_size = parameters.horizontal_parallel_size; - config.distributed_config.pipeline_parallel_size = parameters.pipeline_parallel_size; - config.distributed_config.num_pipeline_micro_batches = parameters.num_pipeline_micro_batches; - config.distributed_config.sliced_schema = parameters.sliced_schema; - config.distributed_config.sliced_axes = parameters.sliced_axes; - config.distributed_config.sliced_tensor_names = parameters.sliced_tensor_names; - - if (parameters.use_mixed_precision) { - training::PipelineTrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mp{}; - mp.use_mixed_precision_initializers = true; - - config.mixed_precision_config = mp; - } - - if (config.distributed_config.pipeline_parallel_size > 1) { - training::PipelineTrainingSession::TrainingConfiguration::PipelineConfiguration pipeline_config; - - // Currently don't support auto-partition. User needs to pass in cut information for pipeline - pipeline_config.do_partition = true; - assert(!parameters.pipeline_cut_info_string.empty()); - - auto process_with_delimiter = [](std::string& input_str, const std::string& delimiter) { - std::vector result; - size_t pos = 0; - while ((pos = input_str.find(delimiter)) != std::string::npos) { - std::string token = input_str.substr(0, pos); - result.emplace_back(token); - input_str.erase(0, pos + delimiter.length()); - } - // push the last split of substring into result. - result.emplace_back(input_str); - return result; - }; - - auto process_cut_info = [&](std::string& cut_info_string) { - std::vector cut_list; - const std::string group_delimiter = ","; - const std::string edge_delimiter = ":"; - const std::string consumer_delimiter = "/"; - const std::string producer_consumer_delimiter = "-"; - - auto cut_info_groups = process_with_delimiter(cut_info_string, group_delimiter); - for (auto& cut_info_group : cut_info_groups) { - PipelineTrainingSession::TrainingConfiguration::CutInfo cut_info; - auto cut_edges = process_with_delimiter(cut_info_group, edge_delimiter); - for (auto& cut_edge : cut_edges) { - auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter); - if (process_edge.size() == 1) { - PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]}; - cut_info.emplace_back(edge); - } else { - ORT_ENFORCE(process_edge.size() == 2); - auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter); - - PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0], consumer_list}; - cut_info.emplace_back(edge); - } - } - cut_list.emplace_back(cut_info); - } - return cut_list; - }; - - pipeline_config.cut_list = process_cut_info(parameters.pipeline_cut_info_string); - config.pipeline_config = pipeline_config; - } - config.loss_name = parameters.loss_output_name; - - if (!parameters.training_optimizer_name.empty()) { - training::PipelineTrainingSession::TrainingConfiguration::OptimizerConfiguration opt{}; - opt.name = parameters.training_optimizer_name; - opt.learning_rate_input_name = parameters.lr_params_feed_name; - opt.weight_attributes_generator = [¶meters](const std::string& weight_name) { - const auto it = parameters.optimizer_attributes_map.find(weight_name); - ORT_ENFORCE( - it != parameters.optimizer_attributes_map.end(), - "Failed to find attribute map for weight ", weight_name); - return it->second; - }; - opt.weight_int_attributes_generator = [¶meters](const std::string& weight_name) { - const auto it = parameters.optimizer_int_attributes_map.find(weight_name); - ORT_ENFORCE( - it != parameters.optimizer_int_attributes_map.end(), - "Failed to find int attribute map for weight ", weight_name); - return it->second; - }; - opt.use_mixed_precision_moments = parameters.use_fp16_moments; - opt.do_all_reduce_in_mixed_precision_type = true; - // TODO: this mapping is temporary. - // For now, nccl allreduce kernel only implements for allreduce_post_accumulation - // hovorod allreduce kernel only implements for not allreduce_post_accumulation. - // eventually we will have one all reduce kernel and let opt to have - // an allreduce_post_accumulation option and remove the use_nccl option. - opt.use_nccl = parameters.allreduce_post_accumulation; - opt.deepspeed_zero = onnxruntime::training::ZeROConfig(parameters.deepspeed_zero_stage); - opt.enable_grad_norm_clip = parameters.enable_grad_norm_clip; - - // TODO reduction types - if (parameters.enable_adasum) { -#ifdef USE_CUDA - opt.adasum_reduction_type = training::AdasumReductionType::GpuHierarchicalReduction; -#else - opt.adasum_reduction_type = training::AdasumReductionType::CpuReduction; -#endif - } - - config.optimizer_config = opt; - } - - if (!parameters.optimizer_initial_state.empty()) { - config.init_optimizer_states = parameters.optimizer_initial_state; - } - - config.gradient_graph_config.use_memory_efficient_gradient = parameters.use_memory_efficient_gradient; - config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs; - - config.graph_transformer_config.attn_dropout_recompute = parameters.attn_dropout_recompute; - config.graph_transformer_config.gelu_recompute = parameters.gelu_recompute; - config.graph_transformer_config.transformer_layer_recompute = parameters.transformer_layer_recompute; - config.graph_transformer_config.number_recompute_layers = parameters.number_recompute_layers; - config.graph_transformer_config.propagate_cast_ops_config.strategy = parameters.propagate_cast_ops_strategy; - config.graph_transformer_config.propagate_cast_ops_config.level = parameters.propagate_cast_ops_level; - config.graph_transformer_config.propagate_cast_ops_config.allow = parameters.propagate_cast_ops_allow; - - if (!parameters.model_after_graph_transforms_path.empty()) { - config.model_after_graph_transforms_path = ToPathString(parameters.model_after_graph_transforms_path); - } - if (!parameters.model_with_gradient_graph_path.empty()) { - config.model_with_gradient_graph_path = ToPathString(parameters.model_with_gradient_graph_path); - } - if (!parameters.model_with_training_graph_path.empty()) { - config.model_with_training_graph_path = ToPathString(parameters.model_with_training_graph_path); - } - - training::PipelineTrainingSession::TrainingConfigurationResult config_result{}; - - OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result)); - - TrainingConfigurationResult python_config_result{}; - if (config_result.mixed_precision_config_result.has_value()) { - const auto& mp_config_result = config_result.mixed_precision_config_result.value(); - python_config_result.loss_scale_input_name = mp_config_result.loss_scale_input_name; - } - - return python_config_result; -} - #if defined(USE_MPI) void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) { LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank(); @@ -424,7 +241,7 @@ std::unordered_map> Con return py_tensor_state; } -void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) { +void addObjectMethodsForTraining(py::module& m) { py::class_(m, "OrtValueCache") .def(py::init<>()) .def("insert", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name, OrtValue& value) { @@ -451,7 +268,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn py::class_ parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc"); parameters.def(py::init()) .def_readwrite("loss_output_name", &TrainingParameters::loss_output_name) - .def_readwrite("immutable_weights", &TrainingParameters::immutable_weights) .def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train) .def_readwrite("weights_to_train", &TrainingParameters::weights_to_train) .def_readwrite("sliced_tensor_names", &TrainingParameters::sliced_tensor_names) @@ -484,25 +300,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn .def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size) .def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size) .def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size) - .def("set_optimizer_initial_state", - [](TrainingParameters& parameters, const std::unordered_map>& py_state) -> void { - onnxruntime::training::TrainingSession::OptimizerState optim_state; - for (const auto& weight_it : py_state) { - auto state = weight_it.second; - NameMLValMap state_tensors; - for (auto& initializer : state) { - OrtValue ml_value; - - // InputDeflist is null because parameters havent been tied to session yet - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - CreateGenericMLValue(nullptr, GetAllocator(), "", initializer.second, &ml_value, true); - ThrowIfPyErrOccured(); - state_tensors.emplace(initializer.first, ml_value); - } - optim_state.emplace(weight_it.first, state_tensors); - } - parameters.optimizer_initial_state = optim_state; - }) .def_readwrite("model_after_graph_transforms_path", &TrainingParameters::model_after_graph_transforms_path) .def_readwrite("model_with_gradient_graph_path", &TrainingParameters::model_with_gradient_graph_path) .def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path) @@ -611,130 +408,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn }); #endif - py::class_ config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc"); - config_result.def(py::init()) - .def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object { - if (result.loss_scale_input_name.has_value()) { - return py::str{result.loss_scale_input_name.value()}; - } - return py::none(); - }); - - // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user - struct PyTrainingSession : public PyInferenceSession { - PyTrainingSession(std::shared_ptr env, const PySessionOptions& so) - : PyInferenceSession(env, std::make_unique(so.value, *env)) { - } - ~PyTrainingSession() = default; - }; - - py::class_ training_session(m, "TrainingSession"); - training_session - .def(py::init([](const PySessionOptions& so) { - auto& training_env = GetTrainingEnv(); - return std::make_unique(training_env.GetORTEnv(), so); - })) - .def(py::init([]() { - auto& training_env = GetTrainingEnv(); - return std::make_unique(training_env.GetORTEnv(), GetDefaultCPUSessionOptions()); - })) - .def("finalize", [](py::object) { -#if defined(USE_MPI) -#ifdef _WIN32 - // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices - // shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction - // call shutdown_mpi() here instead. - MPIContext::shutdown_mpi(); -#endif -#endif - }) - .def("load_model", [ep_registration_fn](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector& provider_types, const ProviderOptionsVector& provider_options) { - OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path)); - -#if defined(USE_MPI) - bool use_nccl = parameters.allreduce_post_accumulation; - if (!use_nccl && parameters.world_size > 1) - CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger()); -#endif - const auto config_result = ConfigureSessionForTraining(static_cast(sess->GetSessionHandle()), parameters); - - ProviderOptionsVector merged_options; - ResolveExtraProviderOptions(provider_types, provider_options, merged_options); - - InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options); - - return config_result; - }) - .def("read_bytes", [ep_registration_fn](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector& provider_types, const ProviderOptionsVector& provider_options) { - std::istringstream buffer(serialized_model); - OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer)); - -#if defined(USE_MPI) - bool use_nccl = parameters.allreduce_post_accumulation; - if (!use_nccl && parameters.world_size > 1) - CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger()); -#endif - const auto config_result = ConfigureSessionForTraining(static_cast(sess->GetSessionHandle()), parameters); - ProviderOptionsVector merged_options; - ResolveExtraProviderOptions(provider_types, provider_options, merged_options); - - InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options); - - return config_result; - }) - .def("get_state", [](PyTrainingSession* sess) { - NameMLValMap state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetStateTensors(state_tensors)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - // convert to numpy array - std::map rmap; - for (auto& kv : state_tensors) { - if (kv.second.IsTensor()) { - py::object obj; - const Tensor& rtensor = kv.second.Get(); - GetPyObjFromTensor(rtensor, obj, &data_transfer_manager); - rmap.insert({kv.first, obj}); - } else { - throw std::runtime_error("Non tensor type in session state tensors is not expected."); - } - } - return rmap; - }) - .def("get_model_state", [](PyTrainingSession* sess, bool include_mixed_precision_weights) { - std::unordered_map model_state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetModelState(model_state_tensors, include_mixed_precision_weights)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - return ConvertORTTensorMapToNumpy(model_state_tensors, data_transfer_manager); - }) - .def("get_optimizer_state", [](PyTrainingSession* sess) { - std::unordered_map opt_state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetOptimizerState(opt_state_tensors)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - return ConvertORTTensorMapToNumpy(opt_state_tensors, data_transfer_manager); - }) - .def("get_partition_info_map", [](PyTrainingSession* sess) { - std::unordered_map>> part_info_map; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetPartitionInfoMap(part_info_map)); - return part_info_map; - }) - .def("load_state", [](PyTrainingSession* sess, std::unordered_map& state, bool strict) { - NameMLValMap state_tensors; - for (auto initializer : state) { - OrtValue ml_value; - auto px = sess->GetSessionHandle()->GetModelInputs(); - if (!px.first.IsOK() || !px.second) { - throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null"); - } - CreateGenericMLValue(px.second, GetAllocator(), initializer.first, initializer.second, &ml_value); - ThrowIfPyErrOccured(); - state_tensors.insert(std::make_pair(initializer.first, ml_value)); - } - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->SetStateTensors(state_tensors, strict)); - }) - .def("is_output_fp32_node", [](PyTrainingSession* sess, const std::string& output_name) { - return static_cast(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name); - }); - py::class_(m, "PartialGraphExecutionState") .def(py::init([]() { return std::make_unique(); diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 88ef90a7feaa8..4d1db7334f280 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -40,7 +40,7 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* void addGlobalMethods(py::module& m); void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn); -void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn); +void addObjectMethodsForTraining(py::module& m); void addObjectMethodsForEager(py::module& m); #ifdef ENABLE_LAZY_TENSOR void addObjectMethodsForLazyTensor(py::module& m); @@ -339,7 +339,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { } #endif - addObjectMethodsForTraining(m, ORTTrainingRegisterExecutionProviders); + addObjectMethodsForTraining(m); #ifdef ENABLE_LAZY_TENSOR addObjectMethodsForLazyTensor(m); diff --git a/orttraining/orttraining/python/training/__init__.py b/orttraining/orttraining/python/training/__init__.py index 73b1f826f68e1..a3c22686a1039 100644 --- a/orttraining/orttraining/python/training/__init__.py +++ b/orttraining/orttraining/python/training/__init__.py @@ -8,26 +8,16 @@ TrainingParameters, is_ortmodule_available, ) -from onnxruntime.capi.training.training_session import TrainingSession - # Options need to be imported before `ORTTrainer`. -from .orttrainer_options import ORTTrainerOptions -from .orttrainer import ORTTrainer, TrainStepInfo -from . import amp, artifacts, checkpoint, model_desc_validation, optim +from . import amp, artifacts, optim __all__ = [ "PropagateCastOpsStrategy", "TrainingParameters", "is_ortmodule_available", - "TrainingSession", - "ORTTrainerOptions", - "ORTTrainer", - "TrainStepInfo", "amp", "artifacts", - "checkpoint", - "model_desc_validation", "optim", ] diff --git a/orttraining/orttraining/python/training/_checkpoint_storage.py b/orttraining/orttraining/python/training/_checkpoint_storage.py deleted file mode 100644 index 7a8ada7dee96b..0000000000000 --- a/orttraining/orttraining/python/training/_checkpoint_storage.py +++ /dev/null @@ -1,107 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import pickle -from collections.abc import Mapping - -import h5py - - -def _dfs_save(group, save_obj): - """Recursively go over each level in the save_obj dictionary and save values to a hdf5 group""" - - for key, value in save_obj.items(): - if isinstance(value, Mapping): - subgroup = group.create_group(key) - _dfs_save(subgroup, value) - else: - group[key] = value - - -def save(save_obj: dict, path): - """Persists the input dictionary to a file specified by path. - - Saves an hdf5 representation of the save_obj dictionary to a file or a file-like object specified by path. - Values are saved in a format supported by h5py. For example, a PyTorch tensor is saved and loaded as a - numpy object. So, user types may be converted from their original types to numpy equivalent types. - - Args: - save_obj: dictionary that needs to be saved. - save_obj should consist of types supported by hdf5 file format. - if hdf5 does not recognize a type, an exception is raised. - if save_obj is not a dictionary, a ValueError is raised. - path: string representation to a file path or a python file-like object. - if file already exists at path, an exception is raised. - """ - if not isinstance(save_obj, Mapping): - raise ValueError("Object to be saved must be a dictionary") - - with h5py.File(path, "w-") as f: - _dfs_save(f, save_obj) - - -def _dfs_load(group, load_obj): - """Recursively go over each level in the hdf5 group and load the values into the given dictionary""" - - for key in group: - if isinstance(group[key], h5py.Group): - load_obj[key] = {} - _dfs_load(group[key], load_obj[key]) - else: - load_obj[key] = group[key][()] - - -def load(path, key=None): - """Loads the data stored in the binary file specified at the given path into a dictionary and returns it. - - Loads the data from an hdf5 file specified at the given path into a python dictionary. - Loaded dictionary contains numpy equivalents of python data types. For example: - PyTorch tensor -> saved as a numpy array and loaded as a numpy array. - bool -> saved as a numpy bool and loaded as a numpy bool - If a '/' separated key is provided, the value at that hierarchical level in the hdf5 group is returned. - - Args: - path: string representation to a file path or a python file-like object. - if file does not already exist at path, an exception is raised. - key: '/' separated representation of the hierarchy level value that needs to be returned/ - for example, if the saved binary file has structure {a: {b: x, c:y}} and the user would like - to query the value for c, the key provided should be 'a/c'. - the default value of None for key implies that the entire hdf5 file structure needs to be loaded into a dictionary and returned. - - Returns: - a dictionary loaded from the specified binary hdf5 file. - """ - if not h5py.is_hdf5(path): - raise ValueError(f"{path} is not an hdf5 file or a python file-like object.") - - load_obj = {} - with h5py.File(path, "r") as f: - if key: - f = f[key] # noqa: PLW2901 - if isinstance(f, h5py.Dataset): - return f[()] - - _dfs_load(f, load_obj) - - return load_obj - - -def to_serialized_hex(user_dict): - """Serialize the user_dict and convert the serialized bytes to a hex string and return""" - - return pickle.dumps(user_dict).hex() - - -def from_serialized_hex(serialized_hex): - """Convert serialized_hex to bytes and deserialize it and return""" - - # serialized_hex can be either a regular string or a byte string. - # if it is a byte string, convert to regular string using decode() - # if it is a regular string, do nothing to it - try: # noqa: SIM105 - serialized_hex = serialized_hex.decode() - except AttributeError: - pass - return pickle.loads(bytes.fromhex(serialized_hex)) diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py index 4eb79443c8f1a..091274d1d171d 100644 --- a/orttraining/orttraining/python/training/_utils.py +++ b/orttraining/orttraining/python/training/_utils.py @@ -6,11 +6,9 @@ import importlib.util import os import sys -from functools import wraps # noqa: F401 import numpy as np import torch -from onnx import TensorProto # noqa: F401 from packaging.version import Version @@ -23,16 +21,6 @@ def get_device_index(device): return 0 if device.index is None else device.index -def get_device_index_from_input(input): - """Returns device index from a input PyTorch Tensor""" - - if isinstance(input, (list, tuple)): - device_index = get_device_index(input[0].device) - else: - device_index = get_device_index(input.device) - return device_index - - def get_device_str(device): if isinstance(device, str): # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 @@ -50,24 +38,6 @@ def get_device_str(device): return device -def get_all_gradients_finite_name_from_session(session): - """Find all_gradients_finite node on Session graph and return its name""" - - nodes = [x for x in session._outputs_meta if "all_gradients_finite" in x.name] - if len(nodes) != 1: - raise RuntimeError("'all_gradients_finite' node not found within training session") - return nodes[0].name - - -def get_gradient_accumulation_name_from_session(session): - """Find Group_Accumulated_Gradients node on Session graph and return its name""" - - nodes = [x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name] - if len(nodes) != 1: - raise RuntimeError("'Group_Accumulated_Gradients' node not found within training session") - return nodes[0].name - - def dtype_torch_to_numpy(torch_dtype): """Converts PyTorch types to Numpy types @@ -232,111 +202,3 @@ def import_module_from_file(file_path, module_name=None): sys.modules[module_name] = module spec.loader.exec_module(module) return module - - -def state_dict_model_key(): - """Returns the model key name in the state dictionary""" - - return "model" - - -def state_dict_optimizer_key(): - """Returns the optimizer key name in the state dictionary""" - - return "optimizer" - - -def state_dict_partition_info_key(): - """Returns the partition info key name in the state dictionary""" - - return "partition_info" - - -def state_dict_trainer_options_key(): - """Returns the trainer options key name in the state dictionary""" - - return "trainer_options" - - -def state_dict_full_precision_key(): - """Returns the full precision key name in the state dictionary""" - - return "full_precision" - - -def state_dict_original_dimension_key(): - """Returns the original dimension key name in the state dictionary""" - - return "original_dim" - - -def state_dict_sharded_optimizer_keys(): - """Returns the optimizer key names that can be sharded in the state dictionary""" - - return {"Moment_1", "Moment_2"} - - -def state_dict_user_dict_key(): - """Returns the user dict key name in the state dictionary""" - - return "user_dict" - - -def state_dict_trainer_options_mixed_precision_key(): - """Returns the trainer options mixed precision key name in the state dictionary""" - - return "mixed_precision" - - -def state_dict_trainer_options_zero_stage_key(): - """Returns the trainer options zero_stage key name in the state dictionary""" - - return "zero_stage" - - -def state_dict_trainer_options_world_rank_key(): - """Returns the trainer options world_rank key name in the state dictionary""" - - return "world_rank" - - -def state_dict_trainer_options_world_size_key(): - """Returns the trainer options world_size key name in the state dictionary""" - - return "world_size" - - -def state_dict_trainer_options_data_parallel_size_key(): - """Returns the trainer options data_parallel_size key name in the state dictionary""" - - return "data_parallel_size" - - -def state_dict_trainer_options_horizontal_parallel_size_key(): - """Returns the trainer options horizontal_parallel_size key name in the state dictionary""" - - return "horizontal_parallel_size" - - -def state_dict_trainer_options_optimizer_name_key(): - """Returns the trainer options optimizer_name key name in the state dictionary""" - - return "optimizer_name" - - -def state_dict_train_step_info_key(): - """Returns the train step info key name in the state dictionary""" - - return "train_step_info" - - -def state_dict_train_step_info_optimization_step_key(): - """Returns the train step info optimization step key name in the state dictionary""" - - return "optimization_step" - - -def state_dict_train_step_info_step_key(): - """Returns the train step info step key name in the state dictionary""" - - return "step" diff --git a/orttraining/orttraining/python/training/checkpoint.py b/orttraining/orttraining/python/training/checkpoint.py deleted file mode 100644 index d0ff0650662b7..0000000000000 --- a/orttraining/orttraining/python/training/checkpoint.py +++ /dev/null @@ -1,748 +0,0 @@ -import os -import tempfile -import warnings -from enum import Enum - -import numpy as np -import onnx -import torch - -from . import _checkpoint_storage, _utils - -################################################################################ -# Experimental Checkpoint APIs -################################################################################ - - -def experimental_state_dict(ort_trainer, include_optimizer_state=True): - warnings.warn( - "experimental_state_dict() will be deprecated soon. Please use ORTTrainer.state_dict() instead.", - DeprecationWarning, - ) - - if not ort_trainer._training_session: - warnings.warn( - "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling state_dict()." - ) - return ort_trainer._state_dict - - # extract trained weights - session_state = ort_trainer._training_session.get_state() - torch_state = {} - for name in session_state: - torch_state[name] = torch.from_numpy(session_state[name]) - - # extract untrained weights and buffer - for n in ort_trainer._onnx_model.graph.initializer: - if n.name not in torch_state and n.name in ort_trainer.options.utils.frozen_weights: - torch_state[n.name] = torch.from_numpy(np.array(onnx.numpy_helper.to_array(n))) - - # Need to remove redundant (optimizer) initializers to map back to original torch state names - if not include_optimizer_state and ort_trainer._torch_state_dict_keys: - return {key: torch_state[key] for key in ort_trainer._torch_state_dict_keys if key in torch_state} - return torch_state - - -def experimental_load_state_dict(ort_trainer, state_dict, strict=False): - warnings.warn( - "experimental_load_state_dict() will be deprecated soon. Please use ORTTrainer.load_state_dict() instead.", - DeprecationWarning, - ) - - # Note: It may happen ONNX model has not yet been initialized - # In this case we cache a reference to desired state and delay the restore until after initialization - # Unexpected behavior will result if the user changes the reference before initialization - if not ort_trainer._training_session: - ort_trainer._state_dict = state_dict - ort_trainer._load_state_dict_strict = strict - return - - # Update onnx model from loaded state dict - cur_initializers_names = [n.name for n in ort_trainer._onnx_model.graph.initializer] - new_initializers = {} - - for name in state_dict: - if name in cur_initializers_names: - new_initializers[name] = state_dict[name].numpy() - elif strict: - raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.") - - ort_trainer._update_onnx_model_initializers(new_initializers) - - # create new session based on updated onnx model - ort_trainer._state_dict = None - ort_trainer._init_session() - - # load training state - session_state = {name: state_dict[name].numpy() for name in state_dict} - ort_trainer._training_session.load_state(session_state, strict) - - -def experimental_save_checkpoint( - ort_trainer, - checkpoint_dir, - checkpoint_prefix="ORT_checkpoint", - checkpoint_state_dict=None, - include_optimizer_state=True, -): - warnings.warn( - "experimental_save_checkpoint() will be deprecated soon. Please use ORTTrainer.save_checkpoint() instead.", - DeprecationWarning, - ) - - if checkpoint_state_dict is None: - checkpoint_state_dict = {"model": experimental_state_dict(ort_trainer, include_optimizer_state)} - else: - checkpoint_state_dict.update({"model": experimental_state_dict(ort_trainer, include_optimizer_state)}) - - assert os.path.exists(checkpoint_dir), f"checkpoint_dir ({checkpoint_dir}) directory doesn't exist" - - checkpoint_name = _get_checkpoint_name( - checkpoint_prefix, - ort_trainer.options.distributed.deepspeed_zero_optimization.stage, - ort_trainer.options.distributed.world_rank, - ort_trainer.options.distributed.world_size, - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - if os.path.exists(checkpoint_file): - msg = f"{checkpoint_file} already exists, overwriting." - warnings.warn(msg) - torch.save(checkpoint_state_dict, checkpoint_file) - - -def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): - warnings.warn( - "experimental_load_checkpoint() will be deprecated soon. Please use ORTTrainer.load_checkpoint() instead.", - DeprecationWarning, - ) - - checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - is_partitioned = False - if len(checkpoint_files) > 1: - msg = ( - f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." - " Attempting to load ZeRO checkpoint." - ) - warnings.warn(msg) - is_partitioned = True - if (not ort_trainer.options.distributed.deepspeed_zero_optimization.stage) and is_partitioned: - return _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict) - else: - return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) - - -class _AGGREGATION_MODE(Enum): # noqa: N801 - Zero = 0 - Megatron = 1 - - -def _order_paths(paths, D_groups, H_groups): - """Reorders the given paths in order of aggregation of ranks for D and H parallellism respectively - and returns the ordered dict""" - - trainer_options_path_tuples = [] - world_rank = _utils.state_dict_trainer_options_world_rank_key() - - for path in paths: - trainer_options_path_tuples.append( - (_checkpoint_storage.load(path, key=_utils.state_dict_trainer_options_key()), path) - ) - - # sort paths according to rank - sorted_paths = [ - path - for _, path in sorted( - trainer_options_path_tuples, key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank] - ) - ] - - ordered_paths = dict() - ordered_paths["D"] = [[sorted_paths[i] for i in D_groups[group_id]] for group_id in range(len(D_groups))] - ordered_paths["H"] = [[sorted_paths[i] for i in H_groups[group_id]] for group_id in range(len(H_groups))] - - return ordered_paths - - -def _add_or_update_sharded_key( - state_key, state_value, state_sub_dict, model_state_key, state_partition_info, sharded_states_original_dims, mode -): - """Add or update the record for the sharded state_key in the state_sub_dict""" - - # record the original dimension for this state - original_dim = _utils.state_dict_original_dimension_key() - sharded_states_original_dims[model_state_key] = state_partition_info[original_dim] - - axis = 0 - if mode == _AGGREGATION_MODE.Megatron and state_partition_info["megatron_row_partition"] == 0: - axis = -1 - - if state_key in state_sub_dict: - # state_dict already contains a record for this state - # since this state is sharded, concatenate the state value to - # the record in the state_dict - state_sub_dict[state_key] = np.concatenate((state_sub_dict[state_key], state_value), axis) - else: - # create a new entry for this state in the state_dict - state_sub_dict[state_key] = state_value - - -def _add_or_validate_unsharded_key(state_key, state_value, state_sub_dict, mismatch_error_string): - """Add or validate the record for the unsharded state_key in the state_sub_dict""" - - if state_key in state_sub_dict: - # state_dict already contains a record for this unsharded state. - # assert that all values are the same for this previously loaded state - assert (state_sub_dict[state_key] == state_value).all(), mismatch_error_string - else: - # create a new entry for this state in the state_sub_dict - state_sub_dict[state_key] = state_value - - -def _aggregate_model_states( - rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled, mode=_AGGREGATION_MODE.Zero -): - """Aggregates all model states from the rank_state_dict into state_dict""" - - model = _utils.state_dict_model_key() - full_precision = _utils.state_dict_full_precision_key() - partition_info = _utils.state_dict_partition_info_key() - - # if there are no model states in the rank_state_dict, no model aggregation is needed - if model not in rank_state_dict: - return - - if model not in state_dict: - state_dict[model] = {} - - if full_precision not in state_dict[model]: - state_dict[model][full_precision] = {} - - # iterate over all model state keys - for model_state_key, model_state_value in rank_state_dict[model][full_precision].items(): - # ZERO: full precision model states are sharded only when they exist in the partition_info subdict and mixed - # precision training was enabled. for full precision training, full precision model states are not sharded - # MEGATRON : full precision model states are sharded when they exist in the partition_info subdict - if (model_state_key in rank_state_dict[partition_info]) and ( - mode == _AGGREGATION_MODE.Megatron or mixed_precision_enabled - ): - # this model state is sharded - _add_or_update_sharded_key( - model_state_key, - model_state_value, - state_dict[model][full_precision], - model_state_key, - rank_state_dict[partition_info][model_state_key], - sharded_states_original_dims, - mode, - ) - else: - # this model state is not sharded since a record for it does not exist in the partition_info subdict - _add_or_validate_unsharded_key( - model_state_key, - model_state_value, - state_dict[model][full_precision], - f"Value mismatch for model state {model_state_key}", - ) - - -def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode=_AGGREGATION_MODE.Zero): - """Aggregates all optimizer states from the rank_state_dict into state_dict""" - - optimizer = _utils.state_dict_optimizer_key() - partition_info = _utils.state_dict_partition_info_key() - sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() - - # if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed - if optimizer not in rank_state_dict: - return - - if optimizer not in state_dict: - state_dict[optimizer] = {} - - # iterate over all optimizer state keys - for model_state_key, optimizer_dict in rank_state_dict[optimizer].items(): - for optimizer_key, optimizer_value in optimizer_dict.items(): - if model_state_key not in state_dict[optimizer]: - state_dict[optimizer][model_state_key] = {} - - if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]: - # this optimizer state is sharded since a record exists in the partition_info subdict - _add_or_update_sharded_key( - optimizer_key, - optimizer_value, - state_dict[optimizer][model_state_key], - model_state_key, - rank_state_dict[partition_info][model_state_key], - sharded_states_original_dims, - mode, - ) - else: - # this optimizer state is not sharded since a record for it does not exist in the partition_info subdict - # or this optimizer key is not one of the sharded optimizer keys - _add_or_validate_unsharded_key( - optimizer_key, - optimizer_value, - state_dict[optimizer][model_state_key], - f"Value mismatch for model state {model_state_key} and optimizer state {optimizer_key}", - ) - - -def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled): - """Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims""" - - model = _utils.state_dict_model_key() - full_precision = _utils.state_dict_full_precision_key() - optimizer = _utils.state_dict_optimizer_key() - sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() - - for sharded_state_key, original_dim in sharded_states_original_dims.items(): - # reshape model states to original_dim only when mixed precision is enabled - if mixed_precision_enabled and (model in state_dict): - state_dict[model][full_precision][sharded_state_key] = state_dict[model][full_precision][ - sharded_state_key - ].reshape(original_dim) - - # reshape optimizer states to original_dim - if optimizer in state_dict: - for optimizer_key, optimizer_value in state_dict[optimizer][sharded_state_key].items(): - if optimizer_key in sharded_optimizer_keys: - state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim) - - -def _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation): - """Extracts trainer options from rank_state_dict and loads them accordingly on state_dict""" - trainer_options = _utils.state_dict_trainer_options_key() - state_dict[trainer_options] = {} - - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - - state_dict[trainer_options][mixed_precision] = rank_state_dict[trainer_options][mixed_precision] - state_dict[trainer_options][zero_stage] = 0 - state_dict[trainer_options][world_rank] = rank_state_dict[trainer_options][world_rank] if partial_aggregation else 0 - state_dict[trainer_options][world_size] = 1 - state_dict[trainer_options][optimizer_name] = rank_state_dict[trainer_options][optimizer_name] - state_dict[trainer_options][D_size] = 1 - state_dict[trainer_options][H_size] = 1 - - -def _aggregate_megatron_partition_info(rank_state_dict, state_dict): - """Extracts partition_info from rank_state_dict and loads on state_dict for megatron-partitioned weights""" - partition_info = _utils.state_dict_partition_info_key() - if partition_info not in state_dict: - state_dict[partition_info] = {} - - rank_partition_info = rank_state_dict[partition_info] - for model_state_key, partition_info_dict in rank_partition_info.items(): - if model_state_key not in state_dict[partition_info]: - # add partition info only if weight is megatron partitioned - if partition_info_dict["megatron_row_partition"] >= 0: - state_dict[partition_info][model_state_key] = partition_info_dict - - -def _to_pytorch_format(state_dict): - """Convert ORT state dictionary schema (hierarchical structure) to PyTorch state dictionary schema (flat structure)""" - - pytorch_state_dict = {} - for model_state_key, model_state_value in state_dict[_utils.state_dict_model_key()][ - _utils.state_dict_full_precision_key() - ].items(): - # convert numpy array to a torch tensor - pytorch_state_dict[model_state_key] = torch.tensor(model_state_value) - return pytorch_state_dict - - -def _get_parallellism_groups(data_parallel_size, horizontal_parallel_size, world_size): - """Returns the D and H groups for the given sizes""" - num_data_groups = world_size // data_parallel_size - data_groups = [] - for data_group_id in range(num_data_groups): - data_group_ranks = [] - for r in range(data_parallel_size): - data_group_ranks.append(data_group_id + horizontal_parallel_size * r) - data_groups.append(data_group_ranks) - - num_horizontal_groups = world_size // horizontal_parallel_size - horizontal_groups = [] - for hori_group_id in range(num_horizontal_groups): - hori_group_ranks = [] - for r in range(horizontal_parallel_size): - hori_group_ranks.append(hori_group_id * horizontal_parallel_size + r) - horizontal_groups.append(hori_group_ranks) - - return data_groups, horizontal_groups - - -def _aggregate_over_ranks( - ordered_paths, - ranks, - sharded_states_original_dims=None, - mode=_AGGREGATION_MODE.Zero, - partial_aggregation=False, - pytorch_format=True, -): - """Aggregate checkpoint files over set of ranks and return a single state dictionary - - Args: - ordered_paths: list of paths in the order in which they must be aggregated - ranks: list of ranks that are to be aggregated - sharded_states_original_dims: dict containing the original dims for sharded states that are persisted over - multiple calls to _aggregate_over_ranks() - mode: mode of aggregation: Zero or Megatron - partial_aggregation: boolean flag to indicate whether to produce a partially - aggregated state which can be further aggregated over - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict - Returns: - state_dict that can be loaded into an ORTTrainer or into a PyTorch model - """ - state_dict = {} - if sharded_states_original_dims is None: - sharded_states_original_dims = dict() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - - loaded_mixed_precision = None - loaded_world_size = None - loaded_zero_stage = None - loaded_optimizer_name = None - - for i, path in enumerate(ordered_paths): - rank_state_dict = _checkpoint_storage.load(path) - - assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info" - assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options" - assert ( - ranks[i] == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] - ), "Unexpected rank in file at path {}. Expected {}, got {}".format( - path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] # noqa: F821 - ) - if loaded_mixed_precision is None: - loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] - else: - assert ( - loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] - ), f"Mixed precision state mismatch among checkpoint files. File: {path}" - if loaded_world_size is None: - loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] - else: - assert ( - loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] - ), f"World size state mismatch among checkpoint files. File: {path}" - if loaded_zero_stage is None: - loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] - else: - assert ( - loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] - ), f"Zero stage mismatch among checkpoint files. File: {path}" - if loaded_optimizer_name is None: - loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] - else: - assert ( - loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] - ), f"Optimizer name mismatch among checkpoint files. File: {path}" - - # aggregate all model states - _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision, mode) - - if not pytorch_format: - # aggregate all optimizer states if pytorch_format is False - _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode) - - # for D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups - # to aggregate over Zero, and another pass to aggregate Megatron partitioned - # states. Preserve the relevant partition info only for weights that are megatron partitioned for - # a partial aggregation call - if partial_aggregation: - _aggregate_megatron_partition_info(rank_state_dict, state_dict) - - # entry for trainer_options in the state_dict to perform other sanity checks - if _utils.state_dict_trainer_options_key() not in state_dict: - _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation) - - # entry for user_dict in the state_dict if not already present - if ( - _utils.state_dict_user_dict_key() not in state_dict - and _utils.state_dict_user_dict_key() in rank_state_dict - ): - state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()] - - # for a partial aggregation scenario, we might not have the entire tensor aggregated yet, thus skip reshape - if not partial_aggregation: - # reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims - _reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision) - - # return a flat structure for PyTorch model in case pytorch_format is True - # else return the hierarchical structure for ORTTrainer - return _to_pytorch_format(state_dict) if pytorch_format else state_dict - - -def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format): # noqa: N802 - """Aggregate checkpoint files and return a single state dictionary for the D+H - (Zero+Megatron) partitioning strategy. - For D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups - to aggregate over Zero, and another pass over the previously aggregated states - to aggregate Megatron partitioned states. - """ - sharded_states_original_dims = {} - aggregate_data_checkpoint_files = [] - - # combine for Zero over data groups and save to temp file - with tempfile.TemporaryDirectory() as save_dir: - for group_id, d_group in enumerate(D_groups): - aggregate_state_dict = _aggregate_over_ranks( - ordered_paths["D"][group_id], - d_group, - sharded_states_original_dims, - partial_aggregation=True, - pytorch_format=False, - ) - - filename = "ort.data_group." + str(group_id) + ".ort.pt" - filepath = os.path.join(save_dir, filename) - _checkpoint_storage.save(aggregate_state_dict, filepath) - aggregate_data_checkpoint_files.append(filepath) - - assert len(aggregate_data_checkpoint_files) > 0 - - # combine for megatron: - aggregate_state = _aggregate_over_ranks( - aggregate_data_checkpoint_files, - H_groups[0], - sharded_states_original_dims, - mode=_AGGREGATION_MODE.Megatron, - pytorch_format=pytorch_format, - ) - - return aggregate_state - - -def aggregate_checkpoints(paths, pytorch_format=True): - """Aggregate checkpoint files and return a single state dictionary - - Aggregates checkpoint files specified by paths and loads them one at a time, merging - them into a single state dictionary. - The checkpoint files represented by paths must be saved through ORTTrainer.save_checkpoint() function. - The schema of the state_dict returned will be in the same as the one returned by ORTTrainer.state_dict() - - Args: - paths: list of more than one file represented as strings where the checkpoint is saved - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict - Returns: - state_dict that can be loaded into an ORTTrainer or into a PyTorch model - """ - - loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key()) - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - world_size = _utils.state_dict_trainer_options_world_size_key() - - D_size = loaded_trainer_options[D_size] # noqa: N806 - H_size = loaded_trainer_options[H_size] # noqa: N806 - world_size = loaded_trainer_options[world_size] - D_groups, H_groups = _get_parallellism_groups(D_size, H_size, world_size) # noqa: N806 - - combine_zero = loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 - combine_megatron = len(H_groups[0]) > 1 - - # order the paths in the order of groups in which they must be aggregated according to - # data-parallel groups and H-parallel groups obtained - # eg: {'D': [[path_0, path_2],[path_1, path_3]], 'H': [[path_0, path_1],[path_2, path_3]]} - ordered_paths = _order_paths(paths, D_groups, H_groups) - - aggregate_state = None - if combine_zero and combine_megatron: - aggregate_state = _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format) - elif combine_zero: - aggregate_state = _aggregate_over_ranks( - ordered_paths["D"][0], D_groups[0], mode=_AGGREGATION_MODE.Zero, pytorch_format=pytorch_format - ) - elif combine_megatron: - aggregate_state = _aggregate_over_ranks( - ordered_paths["H"][0], H_groups[0], mode=_AGGREGATION_MODE.Megatron, pytorch_format=pytorch_format - ) - - return aggregate_state - - -################################################################################ -# Helper functions -################################################################################ - - -def _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict): - checkpoint_name = _get_checkpoint_name( - checkpoint_prefix, - is_partitioned, - ort_trainer.options.distributed.world_rank, - ort_trainer.options.distributed.world_size, - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if is_partitioned: - assert_msg = ( - f"Couldn't find checkpoint file {checkpoint_file}." - " Optimizer partitioning is enabled using ZeRO. Please make sure the checkpoint file exists " - f"for rank {ort_trainer.options.distributed.world_rank} of {ort_trainer.options.distributed.world_size}" - ) - else: - assert_msg = f"Couldn't find checkpoint file {checkpoint_file}." - assert os.path.exists(checkpoint_file), assert_msg - - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - experimental_load_state_dict(ort_trainer, checkpoint_state["model"], strict=strict) - del checkpoint_state["model"] - return checkpoint_state - - -def _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict): - checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - - ckpt_agg = _CombineZeroCheckpoint(checkpoint_files) - aggregate_state_dict = ckpt_agg.aggregate_checkpoints() - - experimental_load_state_dict(ort_trainer, aggregate_state_dict, strict=strict) - - # aggregate other keys in the state_dict. - # Values will be overwritten for matching keys among workers - all_checkpoint_states = dict() - for checkpoint_file in checkpoint_files: - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - del checkpoint_state["model"] - all_checkpoint_states.update(checkpoint_state) - return all_checkpoint_states - - -def _list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"): - ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)] - ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)] - ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names] - - assert len(ckpt_file_names) > 0, f"No checkpoint found with prefix '{checkpoint_prefix}' at '{checkpoint_dir}'" - return ckpt_file_names - - -def _get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None): - SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806 - MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806 - - if is_partitioned: - filename = MULTIPLE_CHECKPOINT_FILENAME.format( - prefix=prefix, world_rank=world_rank, world_size=(world_size - 1) - ) - else: - filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix) - return filename - - -def _split_state_dict(state_dict): - optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] - split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} - for k, v in state_dict.items(): - mode = "fp32_param" - for optim_key in optimizer_keys: - if k.startswith(optim_key): - mode = "optimizer" - break - if k.endswith("_fp16"): - mode = "fp16_param" - split_sd[mode][k] = v - return split_sd - - -class _CombineZeroCheckpoint: - def __init__(self, checkpoint_files, clean_state_dict=None): - assert len(checkpoint_files) > 0, "No checkpoint files passed" - self.checkpoint_files = checkpoint_files - self.clean_state_dict = clean_state_dict - self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1 - assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files" - self.weight_shape_map = {} - self.sharded_params = set() - - def _split_name(self, name: str): - name_split = name.split("_view_") - view_num = None - if len(name_split) > 1: - view_num = int(name_split[1]) - optimizer_key = "" - mp_suffix = "" - if name_split[0].startswith("Moment_1"): - optimizer_key = "Moment_1_" - elif name_split[0].startswith("Moment_2"): - optimizer_key = "Moment_2_" - elif name_split[0].startswith("Update_Count"): - optimizer_key = "Update_Count_" - elif name_split[0].endswith("_fp16"): - mp_suffix = "_fp16" - param_name = name_split[0] - if optimizer_key: - param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split("_fp16")[0] - return param_name, optimizer_key, view_num, mp_suffix - - def _update_weight_statistics(self, name, value): - if name not in self.weight_shape_map: - self.weight_shape_map[name] = value.size() # original shape of tensor - - def _reshape_tensor(self, key): - value = self.aggregate_state_dict[key] - weight_name, _, _, _ = self._split_name(key) - set_size = self.weight_shape_map[weight_name] - self.aggregate_state_dict[key] = value.reshape(set_size) - - def _aggregate(self, param_dict): - for k, v in param_dict.items(): - weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k) - if view_num is not None: - # parameter is sharded - param_name = optimizer_key + weight_name + mp_suffix - - if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]: - self.sharded_params.add(param_name) - # Found a previous shard of the param, concatenate shards ordered by ranks - self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v)) - else: - self.aggregate_state_dict[param_name] = v - else: - if k in self.aggregate_state_dict: - assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value" - else: - self.aggregate_state_dict[k] = v - self._update_weight_statistics(weight_name, v) - - def aggregate_checkpoints(self): - warnings.warn( - "_CombineZeroCheckpoint.aggregate_checkpoints() will be deprecated soon. " - "Please use aggregate_checkpoints() instead.", - DeprecationWarning, - ) - - checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0] - self.aggregate_state_dict = dict() - - for i in range(self.world_size): - checkpoint_name = _get_checkpoint_name(checkpoint_prefix, True, i, self.world_size) - rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu")) - if "model" in rank_state_dict: - rank_state_dict = rank_state_dict["model"] - - if self.clean_state_dict: - rank_state_dict = self.clean_state_dict(rank_state_dict) - - rank_state_dict = _split_state_dict(rank_state_dict) - self._aggregate(rank_state_dict["fp16_param"]) - self._aggregate(rank_state_dict["fp32_param"]) - self._aggregate(rank_state_dict["optimizer"]) - - for k in self.sharded_params: - self._reshape_tensor(k) - return self.aggregate_state_dict diff --git a/orttraining/orttraining/python/training/model_desc_validation.py b/orttraining/orttraining/python/training/model_desc_validation.py deleted file mode 100644 index dd3f4cb95cd59..0000000000000 --- a/orttraining/orttraining/python/training/model_desc_validation.py +++ /dev/null @@ -1,408 +0,0 @@ -from collections import namedtuple - -import cerberus -import torch - -from ._utils import static_vars - -LEARNING_RATE_IO_DESCRIPTION_NAME = "__learning_rate" -ALL_FINITE_IO_DESCRIPTION_NAME = "__all_finite" -LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME = "__loss_scale_input_name" -GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME = "__gradient_accumulation_name" - - -class _ORTTrainerModelDesc: - def __init__(self, model_desc): - # Keep a copy of original input for debug - self._original = dict(model_desc) - - # Global counter used to validate occurrences of 'is_loss=True' whithin 'model_desc.outputs' - # A stateless validator is used for each tuple, but validation accross the whole list of tuple is needed - # because just one 'is_loss=True' is allowed withing 'model_desc.outputs' list of tuples - _model_desc_outputs_validation.loss_counter = 0 - - # Used for logging purposes - self._main_class_name = self.__class__.__name__ - - # Validates user input - self._validated = dict(self._original) - validator = cerberus.Validator(MODEL_DESC_SCHEMA) - self._validated = validator.validated(self._validated) - if self._validated is None: - raise ValueError(f"Invalid model_desc: {validator.errors}") - - # Normalize inputs to a list of namedtuple(name, shape) - self._InputDescription = namedtuple("InputDescription", ["name", "shape"]) - self._InputDescriptionTyped = namedtuple("InputDescriptionTyped", ["name", "shape", "dtype"]) - for idx, input in enumerate(self._validated["inputs"]): - self._validated["inputs"][idx] = self._InputDescription(*input) - - # Normalize outputs to a list of namedtuple(name, shape, is_loss) - self._OutputDescription = namedtuple("OutputDescription", ["name", "shape", "is_loss"]) - self._OutputDescriptionTyped = namedtuple( - "OutputDescriptionTyped", ["name", "shape", "is_loss", "dtype", "dtype_amp"] - ) - for idx, output in enumerate(self._validated["outputs"]): - if len(output) == 2: - self._validated["outputs"][idx] = self._OutputDescription(*output, False) - else: - self._validated["outputs"][idx] = self._OutputDescription(*output) - - # Hard-code learning rate, all_finite descriptors - self.learning_rate = self._InputDescriptionTyped(LEARNING_RATE_IO_DESCRIPTION_NAME, [1], torch.float32) - - # Convert dict in object - for k, v in self._validated.items(): - setattr(self, k, self._wrap(v)) - - def __repr__(self): - """Pretty representation for a model description class""" - - pretty_msg = "Model description:\n" - - # Inputs - inputs = [] - for i_desc in self.inputs: - if isinstance(i_desc, self._InputDescription): - inputs.append(f"(name={i_desc.name}, shape={i_desc.shape})") - elif isinstance(i_desc, self._InputDescriptionTyped): - inputs.append(f"(name={i_desc.name}, shape={i_desc.shape}, dtype={i_desc.dtype})") - else: - raise ValueError(f"Unexpected type {type(i_desc)} for input description") - - pretty_msg += "\nInputs:" - for idx, item in enumerate(inputs): - pretty_msg += f"\n\t{idx}: {item}" - - # Outputs - outputs = [] - for o_desc in self.outputs: - if isinstance(o_desc, self._OutputDescription): - outputs.append(f"(name={o_desc.name}, shape={o_desc.shape})") - elif isinstance(o_desc, self._OutputDescriptionTyped): - outputs.append( - f"(name={o_desc.name}, shape={o_desc.shape}, dtype={o_desc.dtype}, dtype_amp={o_desc.dtype_amp})" - ) - else: - raise ValueError(f"Unexpected type {type(o_desc)} for output description") - pretty_msg += "\nOutputs:" - for idx, item in enumerate(outputs): - pretty_msg += f"\n\t{idx}: {item}" - - # Learning rate - if self.learning_rate: - pretty_msg += "\nLearning rate: " - pretty_msg += ( - f"(name={self.learning_rate.name}, shape={self.learning_rate.shape}, dtype={self.learning_rate.dtype})" - ) - - # Mixed precision - if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) or getattr( - self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None - ): - pretty_msg += "\nMixed Precision:" - if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None): - pretty_msg += "\n\tis gradients finite: " - pretty_msg += ( - f"(name={self.all_finite.name}, shape={self.all_finite.shape}, dtype={self.all_finite.dtype})" - ) - if getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None): - pretty_msg += "\n\tloss scale input name: " - pretty_msg += f"(name={self.loss_scale_input.name}, shape={self.loss_scale_input.shape}, dtype={self.loss_scale_input.dtype})" - - # Gradient Accumulation steps - if self.gradient_accumulation: - pretty_msg += "\nGradient Accumulation: " - pretty_msg += f"(name={self.gradient_accumulation.name}, shape={self.gradient_accumulation.shape}, dtype={self.gradient_accumulation.dtype})" - - return pretty_msg - - def add_type_to_input_description(self, index, dtype): - """Updates an existing input description at position 'index' with 'dtype' type information - - Args: - index (int): position within 'inputs' description - dtype (torch.dtype): input data type - """ - - assert isinstance(index, int) and index >= 0, "input 'index' must be a positive int" - assert isinstance(dtype, torch.dtype), "input 'dtype' must be a torch.dtype type" - existing_values = (*self.inputs[index],) - if isinstance(self.inputs[index], self._InputDescriptionTyped): - existing_values = (*existing_values[:-1],) - self.inputs[index] = self._InputDescriptionTyped(*existing_values, dtype) - - def add_type_to_output_description(self, index, dtype, dtype_amp=None): - """Updates an existing output description at position 'index' with 'dtype' type information - - Args: - index (int): position within 'inputs' description - dtype (torch.dtype): input data type - dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision - """ - - assert isinstance(index, int) and index >= 0, "output 'index' must be a positive int" - assert isinstance(dtype, torch.dtype), "output 'dtype' must be a torch.dtype type" - assert dtype_amp is None or isinstance( - dtype_amp, torch.dtype - ), "output 'dtype_amp' must be either None or torch.dtype type" - existing_values = (*self.outputs[index],) - if isinstance(self.outputs[index], self._OutputDescriptionTyped): - existing_values = (*existing_values[:-2],) - self.outputs[index] = self._OutputDescriptionTyped(*existing_values, dtype, dtype_amp) - - @property - def gradient_accumulation(self): - return getattr(self, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, None) - - @gradient_accumulation.setter - def gradient_accumulation(self, name): - self._add_output_description( - self, name, [1], False, torch.bool, None, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - @property - def all_finite(self): - return getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) - - @all_finite.setter - def all_finite(self, name): - self._add_output_description( - self, name, [1], False, torch.bool, None, ALL_FINITE_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - @property - def loss_scale_input(self): - return getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None) - - @loss_scale_input.setter - def loss_scale_input(self, name): - self._add_input_description( - self, name, [], torch.float32, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - def _add_input_description(self, node, name, shape, dtype=None, attr_name=None, ignore_duplicate=False): - """Add a new input description into the node object - - If 'dtype' is specified, a typed input description namedtuple(name, shape, dtype) is created. - Otherwise an untyped input description namedtuple(name, shape) is created instead. - - Args: - node (list or object): node to append input description to. When 'node' is 'self.inputs', - a new input description is appended to the list. - Otherwise, a new input description is created as an attribute into 'node' with name 'attr_name' - name (str): name of input description - shape (list): shape of input description - dtype (torch.dtype): input data type - attr_name (str, default is None): friendly name to allow direct access to the output description - ignore_duplicate (bool, default is False): silently skips addition of duplicate inputs - """ - - assert isinstance(name, str) and len(name) > 0, "'name' is an invalid input name" - not_found = True - if not ignore_duplicate: - if id(node) == id(self.inputs): - not_found = all([name not in i_desc.name for i_desc in node]) - assert not_found, f"'name' {name} already exists in the inputs description" - else: - not_found = attr_name not in dir(self) - assert not_found, f"'attr_name' {attr_name} already exists in the 'node'" - elif not not_found: - return - assert isinstance(shape, list) and all( - [(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape] - ), "'shape' must be a list of int or str with length at least 1" - assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type" - if dtype: - new_input_desc = self._InputDescriptionTyped(name, shape, dtype) - else: - new_input_desc = self._InputDescription(name, shape) - - if id(node) == id(self.inputs): - self.inputs.append(new_input_desc) - else: - assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'" - setattr(node, attr_name, new_input_desc) - - def _add_output_description( - self, node, name, shape, is_loss, dtype=None, dtype_amp=None, attr_name=None, ignore_duplicate=False - ): - """Add a new output description into the node object as a tuple - - When (name, shape, is_loss, dtype) is specified, a typed output description is created - Otherwise an untyped output description (name, shape, is_loss) is created instead - - Args: - node (list or object): node to append output description to. When 'node' is 'self.outputs', - a new output description is appended to the list. - Otherwise, a new output description is created as an attribute into 'node' with name 'attr_name' - name (str): name of output description - shape (list): shape of output description - is_loss (bool): specifies whether this output is a loss - dtype (torch.dtype): input data type - dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision. - attr_name (str, default is None): friendly name to allow direct access to the output description - ignore_duplicate (bool, default is False): silently skips addition of duplicate outputs - """ - - assert isinstance(name, str) and len(name) > 0, "'name' is an invalid output name" - assert isinstance(shape, list) and all( - [(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape] - ), "'shape' must be a list of int or str with length at least 1" - assert isinstance(is_loss, bool), "'is_loss' must be a bool" - - not_found = True - if not ignore_duplicate: - if id(node) == id(self.outputs): - not_found = all([name not in o_desc.name for o_desc in node]) - assert not_found, f"'name' {name} already exists in the outputs description" - assert ( - all([not o_desc.is_loss for o_desc in node]) if is_loss else True - ), "Only one 'is_loss' is supported at outputs description" - else: - not_found = attr_name not in dir(self) - assert not_found, f"'attr_name' {attr_name} already exists in the 'node'" - elif not not_found: - return - - assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type" - if dtype: - new_output_desc = self._OutputDescriptionTyped(name, shape, is_loss, dtype, None) - else: - new_output_desc = self._OutputDescription(name, shape, is_loss) - - if id(node) == id(self.outputs): - self.outputs.append(new_output_desc) - else: - assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'" - setattr(node, attr_name, new_output_desc) - - def _wrap(self, v): - """Add 'v' as self's attribute to allow direct access as self.v""" - if isinstance(v, (list)): - return type(v)([self._wrap(v) for v in v]) - elif isinstance( - v, - ( - self._InputDescription, - self._InputDescriptionTyped, - self._OutputDescription, - self._OutputDescriptionTyped, - ), - ): - return v - elif isinstance(v, (tuple)): - return type(v)([self._wrap(v) for v in v]) - elif isinstance(v, (dict, int, float, bool, str)): - return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v - else: - raise ValueError( - f"Unsupported type for model_desc ({v})." - "Only int, float, bool, str, list, tuple and dict are supported" - ) - - -class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc): - r"""Internal class used by ONNX Runtime training backend for input validation - - NOTE: Users MUST NOT use this class in any way! - """ - - def __init__(self, main_class_name, model_desc): - # Used for logging purposes - self._main_class_name = main_class_name - - # Convert dict in object - for k, v in dict(model_desc).items(): - setattr(self, k, self._wrap(v)) - - -def _model_desc_inputs_validation(field, value, error): - r"""Cerberus custom check method for 'model_desc.inputs' - - 'model_desc.inputs' is a list of tuples. - The list has variable length, but each tuple has size 2 - - The first element of the tuple is a string which represents the input name - The second element is a list of shapes. Each shape must be either an int or string. - Empty list represents a scalar output - - Validation is done within each tuple to enforce the schema described above. - - Example: - - .. code-block:: python - - model_desc['inputs'] = [('input1', ['batch', 1024]), - ('input2', []) - ('input3', [512])] - """ - - if not isinstance(value, tuple) or len(value) != 2: - error(field, "must be a tuple with size 2") - if not isinstance(value[0], str): - error(field, "the first element of the tuple (aka name) must be a string") - if not isinstance(value[1], list): - error(field, "the second element of the tuple (aka shape) must be a list") - else: - for shape in value[1]: - if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool): - error(field, "each shape must be either a string or integer") - - -@static_vars(loss_counter=0) -def _model_desc_outputs_validation(field, value, error): - r"""Cerberus custom check method for 'model_desc.outputs' - - 'model_desc.outputs' is a list of tuples with variable length. - The first element of the tuple is a string which represents the output name - The second element is a list of shapes. Each shape must be either an int or string. - Empty list represents a scalar output - The third element is optional and is a flag that signals whether the output is a loss value - - Validation is done within each tuple to enforce the schema described above, but also - throughout the list of tuples to ensure a single 'is_loss=True' occurrence. - - Example: - - .. code-block:: python - - model_desc['outputs'] = [('output1', ['batch', 1024], is_loss=True), - ('output2', [], is_loss=False) - ('output3', [512])] - """ - - if not isinstance(value, tuple) or len(value) < 2 or len(value) > 3: - error(field, "must be a tuple with size 2 or 3") - if len(value) == 3 and not isinstance(value[2], bool): - error(field, "the third element of the tuple (aka is_loss) must be a boolean") - elif len(value) == 3: - if value[2]: - _model_desc_outputs_validation.loss_counter += 1 - if _model_desc_outputs_validation.loss_counter > 1: - error(field, "only one is_loss can bet set to True") - if not isinstance(value[0], str): - error(field, "the first element of the tuple (aka name) must be a string") - if not isinstance(value[1], list): - error(field, "the second element of the tuple (aka shape) must be a list") - else: - for shape in value[1]: - if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool): - error(field, "each shape must be either a string or integer") - - -# Validation schema for model description dictionary -MODEL_DESC_SCHEMA = { - "inputs": { - "type": "list", - "required": True, - "minlength": 1, - "schema": {"check_with": _model_desc_inputs_validation}, - }, - "outputs": { - "type": "list", - "required": True, - "minlength": 1, - "schema": {"check_with": _model_desc_outputs_validation}, - }, -} diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index cac9b6fc4a2b6..462491365c1fa 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -52,7 +52,8 @@ def codegen(self, node: IRNode, context: CodegenContext, code_buffer: CodeBuffer def _get_elementwise_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]: if offset_calc.is_x_reduced(arg_name): - return "", "" + # Scalar. + return "tl.full([1], 0, tl.int32)", "" if offset_calc.is_same_x_shape(arg_name): return "xindex", "xmask" if offset_calc.requires_x_mask else "" strides = offset_calc.get_input_strides(arg_name) @@ -88,13 +89,16 @@ def _get_reduce_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) if offset_calc.requires_r_mask: mask_strs.append("rmask") + # If both is_x_reduced and is_r_reduced are True, it's scalar. + if len(offset_strs) == 0: + offset_strs.append("tl.full([1, 1], 0, tl.int32)") return " + ".join(offset_strs), " & ".join(mask_strs) - def _get_offset_mask(self, node: OffsetCalculator, arg_name: str) -> Tuple[str, str]: + def _get_offset_mask(self, offset_calc: OffsetCalculator, arg_name: str) -> Tuple[str, str]: return ( - self._get_reduce_offset_mask(node, arg_name) - if node.is_reduction - else self._get_elementwise_offset_mask(node, arg_name) + self._get_reduce_offset_mask(offset_calc, arg_name) + if offset_calc.is_reduction + else self._get_elementwise_offset_mask(offset_calc, arg_name) ) def IONode(self, node: IONode, context: CodegenContext, code_buffer: CodeBuffer, indent: int): # noqa: N802 diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index 628ea822ff55b..50121cbf49804 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -131,7 +131,7 @@ def register_tensor_arg(self, tensor_arg: TensorArg): input_shape = tensor_arg.shape if tensor_arg.name in self.reduced_args: assert self.is_reduction - reduced_rank = len(input_shape) - len(self.reduce_axes) + reduced_rank = len(self.target_shape) - len(self.reduce_axes) if len(input_shape) < reduced_rank: input_shape = [sympy.Integer(1)] * (reduced_rank - len(input_shape)) + input_shape input_shape = ( @@ -143,7 +143,9 @@ def register_tensor_arg(self, tensor_arg: TensorArg): input_shape = [sympy.Integer(1)] * (len(self.target_shape) - len(input_shape)) + input_shape running_stride = sympy.Integer(1) for i in range(len(self.target_shape) - 1, -1, -1): - if self.target_shape[i] == input_shape[i]: + if self.target_shape[i] == input_shape[i] and not ( + tensor_arg.name in self.reduced_args and i in self.reduce_axes + ): strides.insert(0, running_stride) running_stride = running_stride * input_shape[i] else: diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py index 8a642a1f7f26c..1fe61750e651e 100644 --- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py @@ -87,7 +87,8 @@ def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str def get_config() -> str: """ Get the supported ops and other configs in JSON format to control the Triton fusion on backend side. - All supported ops are from _op_config.py. The Triton fusion will try to fuse subgraphs with connected supported ops. + All supported ops are from user config specified by env ORTMODULE_TRITON_CONFIG_FILE or from _op_config.py. + The Triton fusion will try to fuse subgraphs with connected supported ops. The initializer value can be "none", "scalar", and "all". "none": no initializer will be added to subgraphs. "scalar": only related scalar initializers will be added to subgraphs. @@ -95,6 +96,11 @@ def get_config() -> str: The min_nodes is used to control the minimum number of non-no-op nodes in a subgraph. """ + config_file = os.getenv("ORTMODULE_TRITON_CONFIG_FILE", "") + if config_file and os.path.exists(config_file): + with open(config_file, encoding="UTF-8") as f: + return f.read() + config = {"ops": get_supported_ops(), "initializer": "scalar", "min_nodes": 2} return json.dumps(config) diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py deleted file mode 100644 index d5a488c436a1d..0000000000000 --- a/orttraining/orttraining/python/training/orttrainer.py +++ /dev/null @@ -1,1537 +0,0 @@ -import copy -import io -import os -import warnings -from functools import partial -from inspect import signature - -import numpy as np -import onnx -import torch - -import onnxruntime as ort -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - -from . import _checkpoint_storage, _utils, amp, checkpoint, optim, postprocess -from .model_desc_validation import _ORTTrainerModelDesc -from .orttrainer_options import ORTTrainerOptions - - -class TrainStepInfo: - r"""Private class used to store runtime information from current train step. - - After every train step, :py:meth:`ORTTrainer.train_step` updates the internal instance of - :py:class:`.TrainStepInfo` residing on :py:class:`.ORTTrainer` with relevant information - from the forward pass. - - This class shouldn't be accessed directly by the user, unless they really know what they are doing. - Instead, :py:class:`.ORTTrainer` passes it to relevant class methods automatically, - such as :py:method:`._LRScheduler.get_lr` or :py:class:`.LossScaler.update`. - - Args: - optimizer_config (optim._OptimizerConfig): reference to optimizer config - all_finite (bool, default is True): flag that indicates whether all gradients are still finite after last step - fetches (list of str, default is []): list of output names to fetch from train_step/eval_step. Set it to [] to reset normal behavior. - optimization_step (int): indicates the number of optimizations performed. Used for learning rate scheduling - step (int): indicates current training step. Used for gradient accumulation - - Example: - - .. code-block:: python - - info = TrainStepInfo(optimizer_config=optim.SGDConfig(lr=0.01)) - if info.all_finite: - print(f'Yay, all gradients are finite at {step} step!') - - """ - - def __init__(self, optimizer_config, all_finite=True, fetches=[], optimization_step=0, step=0): # noqa: B006 - assert isinstance(optimizer_config, optim._OptimizerConfig), "optimizer_config must be a optim._OptimizerConfig" - assert isinstance(all_finite, bool), "all_finite must be a bool" - assert isinstance(fetches, list) and all( - [isinstance(item, str) for item in fetches] - ), "fetches must be a list of str" - assert isinstance(optimization_step, int) and optimization_step >= 0, "optimization_step must be a positive int" - assert isinstance(step, int) and step >= 0, "step must be a positive int" - - self.optimizer_config = optimizer_config - self.all_finite = all_finite - self.fetches = fetches - self.optimization_step = optimization_step - self.step = step - - -class ORTTrainer: - r"""Pytorch frontend for ONNX Runtime training - - Entry point that exposes the C++ backend of ORT as a Pytorch frontend. - - Args: - model (torch.nn.Module or onnx.ModelProto): either a PyTorch or ONNX model. - When a PyTorch model and :py:attr:`loss_fn` are specified, :py:attr:`model` and :py:obj:`loss_fn` are combined. - When a ONNX model is provided, the loss is identified by the flag :py:obj:`is_loss=True` in one of the :py:attr:`.model_desc.outputs` entries. - model_desc (dict): model input and output description. - This is used to identify inputs and outputs and their shapes, so that ORT can generate back propagation graph, plan memory allocation for - training, and perform optimizations. - :py:attr:`model_desc` must be consistent with the training :py:attr:`model` and have the following (:py:obj:`dict`) schema - :py:obj:`{ 'inputs': [tuple(name, shape)], 'outputs': [tuple(name, shape, is_loss)]}`. - :py:attr:`name` is a string representing the name of input or output of the model. - For :py:obj:`model_desc['inputs']` entries, :py:attr:`name` must match input names of the original PyTorch model's :py:meth:`torch.nn.Module.forward` method. - For ONNX models, both name and order of input names must match. - For :py:obj:`model_desc['outputs']` entries, the order must match the original PyTorch's output as returned by :py:meth:`torch.nn.Module.forward` method. - For ONNX models, both name and order of output names must match. - :py:attr:`shape` is a list of string or integers that describes the shape of the input/output. - Each dimension size can be either a string or an int. String means the dimension size is dynamic, while integers mean static dimensions. - An empty list implies a scalar. - Lastly, :py:attr:`is_loss` is a boolean (default is False) that flags if this output is considered a loss. - ORT backend needs to know which output is loss in order to generate back propagation graph. - Loss output must be specified when either :py:attr:`loss_fn` is specified or when loss is embedded in the model. - Note that only one loss output is supported per model. - optimizer_config (optim._OptimizerConfig): optimizer config. - One of :py:class:`.optim.AdamConfig`, :py:class:`.optim.LambConfig` or :py:class:`.optim.SGDConfig`. - loss_fn (callable, default is None): a PyTorch loss function. - It takes two inputs [prediction, label] and outputs a scalar loss tensor. - If provided, :py:attr:`loss_fn` is combined with the PyTorch :py:attr:`model` to form a combined PyTorch model. - Inputs to the combined PyTorch model are concatenation of the :py:attr:`model`'s input and :py:attr:`loss_fn`'s label input. - Outputs of the combined PyTorch model are concatenation of :py:attr:`loss_fn`'s loss output and :py:attr:`model`'s outputs. - options (ORTTrainerOptions, default is None): options for additional features. - Example: - - .. code-block:: python - - model = ... - loss_fn = ... - model_desc = { - "inputs": [ - ("input_ids", ["batch", "max_seq_len_in_batch"]), - ("attention_mask", ["batch", "max_seq_len_in_batch"]), - ("token_type_ids", ["batch", "max_seq_len_in_batch"]), - ("masked_lm_labels", ["batch", "max_seq_len_in_batch"]), - ("next_sentence_label", ["batch", 1]) - ], - "outputs": [ - ("loss", [], True), - ], - } - optim_config = optim.LambConfig(param_groups = [ { 'params' : ['model_param0'], 'alpha' : 0.8, 'beta' : 0.7}, - { 'params' : ['model_param1' , 'model_param_2'], 'alpha' : 0.0} - ], - alpha=0.9, beta=0.999) - ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn) - """ - - def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): - warnings.warn( - "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", - FutureWarning, - ) - - assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model" - assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'" - assert isinstance( - optim_config, optim._OptimizerConfig - ), "'optim_config' is required and must be any of 'AdamConfig', 'LambConfig' or 'SGDConfig'" - assert loss_fn is None or ( - callable(loss_fn) and len(signature(loss_fn).parameters) == 2 - ), "'loss_fn' must be either 'None' or a callable with two parameters" - assert options is None or isinstance( - options, ORTTrainerOptions - ), "'options' must be either 'None' or 'ORTTrainerOptions'" - - # Model + Loss validation - # Supported combinarios are - # ---------------------------------------- - # | | Model | Loss | - # ---------------------------------------- - # | 1 | torch.nn.Module | None | - # | 2 | torch.nn.Module | torch.nn.Module | - # | 3 | ONNX | None | - # ---------------------------------------- - self._torch_model = None - self._onnx_model = None - if isinstance(model, torch.nn.Module): - assert loss_fn is None or isinstance( - model, torch.nn.Module - ), "'loss_fn' must be either 'None' or 'torch.nn.Module'" - self._torch_model = model - self.loss_fn = loss_fn - # TODO: Remove when experimental checkpoint functions are removed. - self._torch_state_dict_keys = list(model.state_dict().keys()) - elif isinstance(model, onnx.ModelProto): - assert loss_fn is None, "'loss_fn' must not be specified when 'model' is an ONNX model" - self._onnx_model = model - self.loss_fn = None - else: - raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'") - - self.model_desc = _ORTTrainerModelDesc(model_desc) - self.optim_config = optim_config - - # ORTTrainerOptions - if not options: - options = ORTTrainerOptions() - self.options = options - if self.options.mixed_precision.enabled and not self.options.mixed_precision.loss_scaler: - # TODO: Move this to model_desc_validation.py - self.options.mixed_precision.loss_scaler = amp.loss_scaler.DynamicLossScaler() - # Post processing ONNX model given as input - if self._onnx_model: - if self.options._internal_use.enable_internal_postprocess: - self._onnx_model = postprocess.run_postprocess(self._onnx_model) - if self.options._internal_use.extra_postprocess: - self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model) - assert isinstance(self._onnx_model, onnx.ModelProto), "'extra_postprocess' must return a ONNX model" - - # When input model is already ONNX (and not exported from Pytorch within ORTTrainer), - # append 'dtype' from ONNX into model description's - for idx_i, i_desc in enumerate(self.model_desc.inputs): - dtype = None - for onnx_input in self._onnx_model.graph.input: - if onnx_input.name == i_desc.name: - dtype = _utils.dtype_onnx_to_torch(onnx_input.type.tensor_type.elem_type) - self.model_desc.add_type_to_input_description(idx_i, dtype) - break - assert dtype is not None, f"ONNX model with unknown input type ({i_desc.name})" - for idx_o, o_desc in enumerate(self.model_desc.outputs): - dtype = None - for onnx_output in self._onnx_model.graph.output: - if onnx_output.name == o_desc.name: - dtype = _utils.dtype_onnx_to_torch(onnx_output.type.tensor_type.elem_type) - self.model_desc.add_type_to_output_description(idx_o, dtype) - break - assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})" - - try: - from torch.utils.cpp_extension import ROCM_HOME - - self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) - except ImportError: - self.is_rocm_pytorch = False - - # TODO: Remove when experimental checkpoint functions are removed. - self._state_dict = {} - - self._train_step_info = TrainStepInfo(self.optim_config) - self._training_session = None - self._load_state_dict = None - self._init_session( - provider_options=self.options._validated_opts["provider_options"], - session_options=self.options.session_options, - ) - - def eval_step(self, *args, **kwargs): - r"""Evaluation step method - - Args: - *args: Arbitrary arguments that are used as model input (data only) - **kwargs: Arbitrary keyword arguments that are used as model input (data only) - - Returns: - ordered :py:obj:`list` with model outputs as described by :py:attr:`.ORTTrainer.model_desc` - """ - # Get data. CombineTorchModelLossFn takes label as last input and outputs loss first - sample_input = self._prepare_model_input(self.model_desc.inputs, None, None, *args, **kwargs) - - # Export model to ONNX - if self._onnx_model is None: - if self._torch_model is not None: - self._init_onnx_model(sample_input) - else: - raise RuntimeError("Model is uninitialized. Only ONNX and PyTorch models are supported") - - # Prepare input/output description - inputs_desc = self.model_desc.inputs - outputs_desc = self.model_desc.outputs - if self._train_step_info.fetches: - outputs_desc = [o_desc for o_desc in outputs_desc if o_desc.name in self._train_step_info.fetches] - if len(outputs_desc) != len(self._train_step_info.fetches): - raise RuntimeError("The specified fetches list contains invalid output names") - - # Normalize input - if not isinstance(sample_input, (list, tuple)): - sample_input = (sample_input,) - - # RunOptions - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - run_options.training_mode = False - - # Run a eval step and return - session_run_results = self._training_session_run_helper( - False, sample_input, inputs_desc, outputs_desc, run_options - ) - - # Output must be returned in the same order as defined in the model description - results = [session_run_results[o_desc.name] for o_desc in outputs_desc] - return results[0] if len(results) == 1 else results - - def save_as_onnx(self, path): - r"""Persists ONNX model into :py:attr:`path` - - The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard. - The graph includes full information, including inference and training metadata. - - Args: - path (str): Full path, including filename, to save the ONNX model in the filesystem - - Raises: - RuntimeWarning: raised when neither `train_step` or `eval_step` was called at least once - ValueError: raised when `path` is not valid path - """ - if not self._training_session: - warnings.warn( - "Training session is not initialized yet. " - "'train_step' or 'eval_step' methods must be executed at least once before calling 'save_as_onnx()'." - ) - return - state_tensors = self._training_session.get_state() - self._update_onnx_model_initializers(state_tensors) - - assert isinstance(path, str), "'path' must be a valid path string" - dir_name = os.path.dirname(path) - file_name = os.path.basename(path) - if (dir_name and not os.path.exists(dir_name)) or not file_name: - warnings.warn("'path' is not valid or does not exist") - return - - with open(path, "wb") as f: - f.write(self._onnx_model.SerializeToString()) - - def _check_model_export(self, input): - from numpy.testing import assert_allclose - from onnx import TensorProto, helper, numpy_helper # noqa: F401 - - onnx_model_copy = copy.deepcopy(self._onnx_model) - - # Mute the dropout nodes - dropout_nodes = [n for n in onnx_model_copy.graph.node if n.op_type == "Dropout"] - for node in dropout_nodes: - ratio_node = next(n for n in onnx_model_copy.graph.node if node.input[1] in n.output) - training_mode_node = next(n for n in onnx_model_copy.graph.node if node.input[2] in n.output) - - training_mode_node.attribute.pop() - ratio_node.attribute.pop() - new_training_mode_arr = np.array(False, dtype=bool) - new_ratio_arr = np.array(0.0, dtype=np.float32) - new_training_mode = numpy_helper.from_array(new_training_mode_arr) - new_ratio = numpy_helper.from_array(new_ratio_arr) - training_mode_node.attribute.add().t.CopyFrom(new_training_mode) - ratio_node.attribute.add().t.CopyFrom(new_ratio) - training_mode_node.attribute[0].type = 4 - ratio_node.attribute[0].type = 4 - training_mode_node.attribute[0].name = "value" - ratio_node.attribute[0].name = "value" - - _inference_sess = ort.InferenceSession( - onnx_model_copy.SerializeToString(), providers=ort.get_available_providers() - ) - inf_inputs = {} - for i, input_elem in enumerate(input): - inf_inputs[_inference_sess.get_inputs()[i].name] = input_elem.cpu().numpy() - _inference_outs = _inference_sess.run(None, inf_inputs) - for torch_item, ort_item in zip(self.torch_sample_outputs, _inference_outs): - assert_allclose( - torch_item, - ort_item, - rtol=1e-2, - atol=1e-6, - err_msg="Mismatch between outputs of PyTorch model and exported ONNX model. " - "Note that different backends may exhibit small computational differences." - "If this is within acceptable margin, or if there is random generator " - "in the model causing inevitable mismatch, you can proceed training by " - "setting the flag debug.check_model_export to False.", - ) - - def train_step(self, *args, **kwargs): - r"""Train step method - - After forward pass, an ordered list with all outputs described at :py:attr:`ORTTrainer.model_desc` is returned. - Additional information relevant to the train step is maintend by :py:attr:`ORTTrainer._train_step_info`. - See :py:class:`.TrainStepInfo` for details. - - Args: - *args: Arbitrary arguments that are used as model input (data only) - **kwargs: Arbitrary keyword arguments that are used as model input (data only) - - Returns: - ordered :py:obj:`list` with model outputs as described by :py:attr:`ORTTrainer.model_desc` - """ - # Export model to ONNX - if self._onnx_model is None: - sample_input = self._prepare_model_input(self.model_desc.inputs, None, None, *args, **kwargs) - self._init_onnx_model(sample_input) - - # Debug Model Export if indicated - if self.options.debug.check_model_export: - self._check_model_export(sample_input) - - # Prepare inputs+lr and output descriptions - inputs_desc = self._model_desc_inputs_with_lr - outputs_desc = self.model_desc.outputs - - # Train step must be incremented *before* gradient accumulation code - # Gradients are accumulated when - # self._train_step_info.step % self.options.batch.gradient_accumulation_steps != 0, - # and they are updated otherwise - self._train_step_info.step += 1 - - # RunOptions - run_options = None - mixed_precision_without_fetches = False - if self._train_step_info.fetches: - outputs_desc = [o_desc for o_desc in outputs_desc if o_desc.name in self._train_step_info.fetches] - if len(outputs_desc) != len(self._train_step_info.fetches): - raise RuntimeError("The specified fetches list contains invalid output names") - elif self._train_step_info.step % self.options.batch.gradient_accumulation_steps != 0: - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - outputs_desc = self._model_desc_outputs_with_gradient_accumulation - elif self.options.mixed_precision.enabled: - mixed_precision_without_fetches = True - outputs_desc = self._model_desc_outputs_with_all_finite - - # Update Learning Rate if Necessary - lr = self.optim_config.lr - if self.options.lr_scheduler: - lr = self.options.lr_scheduler._step(self._train_step_info)[0] - - # Loss Scale for mixed precision - loss_scale = None - if self.options.mixed_precision.enabled: - loss_scaler = self.options.mixed_precision.loss_scaler - assert loss_scaler, "Loss scaler is required when mixed precision is enabled" - loss_scale = loss_scaler.loss_scale - inputs_desc = self._model_desc_inputs_with_lr_and_loss_scale - - # Get data. CombineTorchModelLossFn takes label as last input and outputs loss first - input = self._prepare_model_input(inputs_desc, lr, loss_scale, *args, **kwargs) - - # Normalize input - if not isinstance(args, (list, tuple)): - args = (args,) - - # Run a train step and return - session_run_results = self._training_session_run_helper(True, input, inputs_desc, outputs_desc, run_options) - if mixed_precision_without_fetches: - # After session run with all_fp32_gradients_finite, we need to clear the training I/O binding's output - # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce - # because all_fp32_gradients_finite is still in the feed. - self._train_io_binding.clear_binding_outputs() - - is_all_finite = session_run_results[self.model_desc.all_finite.name] - self._train_step_info.all_finite = is_all_finite - if loss_scaler: - loss_scaler.update(self._train_step_info) - if is_all_finite: - # Optimization step must be incremented *after* optimization is successful - self._train_step_info.optimization_step += 1 - elif self._train_step_info.step % self.options.batch.gradient_accumulation_steps == 0: - # Optimization step must be incremented *after* optimization is successful - self._train_step_info.optimization_step += 1 - - # Output must be returned in the same order as defined in the model description - # or in the order specified by TrainStepInfo.fetches, if applicable - if self._train_step_info.fetches: - results = [session_run_results[o_desc] for o_desc in self._train_step_info.fetches] - else: - results = [session_run_results[o_desc.name] for o_desc in self.model_desc.outputs] - return results[0] if len(results) == 1 else results - - def _convert_torch_model_loss_fn_to_onnx(self, inputs, device): - # Dynamic axes - dynamic_axes = {} - for input in self.model_desc.inputs: - symbolic_axis = {} - for i, axis in enumerate(input.shape): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[input.name] = symbolic_axis - for output in self.model_desc.outputs: - symbolic_axis = {} - for i, axis in enumerate(output.shape): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[output.name] = symbolic_axis - - if isinstance(inputs, torch.Tensor): - inputs = [inputs] - if isinstance(inputs, dict): - sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs] - elif isinstance(inputs, (list, tuple)): - sample_inputs = [ - input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs) - ] - else: - raise RuntimeError( - "Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported." - ) - - # PyTorch ONNX exporter does not match argument names - # This is an issue because the ONNX graph depends on all inputs to be specified - - # Validate loss_fn - if self.loss_fn: - sig_loss = signature(self.loss_fn) - if len(sig_loss.parameters) != 2: - raise RuntimeError("loss function should take two arguments - predict and label.") - - # Basic input names from model - input_names = [input.name for input in self.model_desc.inputs] - sig = signature(self._torch_model.forward) - ordered_input_list = list(sig.parameters.keys()) - - # Label from loss_fn goes after model input - if self.loss_fn: - ordered_input_list = [*ordered_input_list, list(sig_loss.parameters.keys())[1]] - - class CombineTorchModelLossFnWrapInput(torch.nn.Module): - def __init__(self, model, loss_fn, input_names): - super().__init__() - self.model = model - self.loss_fn = loss_fn - self.input_names = input_names - - def forward(self, *inputs): - sig = signature(self.model.forward) - - input_dict = {} - for key in sig.parameters: - if key in self.input_names: - input_dict[key] = inputs[self.input_names.index(key)] - - model_out = self.model(**input_dict) - if self.loss_fn is None: - return model_out - - label = inputs[-1] - preds = model_out - return self.loss_fn(preds, label), preds - - model = CombineTorchModelLossFnWrapInput(self._torch_model, self.loss_fn, input_names) - - # Do an inference to grab output types - model.eval() - with torch.no_grad(): - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - try: - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(model) - except Exception: - model_copy = model - warnings.warn( - "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." - " Compute will continue, but unexpected results may occur!" - ) - sample_outputs = model_copy(*sample_inputs_copy) - self.torch_sample_outputs = sample_outputs - model.train() - - if isinstance(sample_outputs, torch.Tensor): - sample_outputs = [sample_outputs] - - # Append 'dtype' for model description's inputs/outputs - for idx_i, sample_input in enumerate(sample_inputs): - if idx_i < len(self.model_desc.inputs): - self.model_desc.add_type_to_input_description(idx_i, sample_input.dtype) - for idx_o, sample_output in enumerate(sample_outputs): - if idx_o < len(self.model_desc.outputs): - self.model_desc.add_type_to_output_description(idx_o, sample_output.dtype) - - # Export the model to ONNX - f = io.BytesIO() - - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - - # Handle contrib OPs support - from onnxruntime.tools import pytorch_export_contrib_ops - - if self.options._internal_use.enable_onnx_contrib_ops: - pytorch_export_contrib_ops.register() - else: - # Unregister in case they were registered in previous calls. - pytorch_export_contrib_ops.unregister() - - # Export torch.nn.Module to ONNX - torch.onnx.export( - model, - tuple(sample_inputs_copy), - f, - input_names=[input.name for input in self.model_desc.inputs], - output_names=[output.name for output in self.model_desc.outputs], - opset_version=self.options._internal_use.onnx_opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False, - training=torch.onnx.TrainingMode.TRAINING, - ) - onnx_model = onnx.load_model_from_string(f.getvalue()) - - # Remove 'model.' prefix introduced by CombineTorchModelLossFn class - if isinstance(model, CombineTorchModelLossFnWrapInput): - replace_name_dict = {} - for n in onnx_model.graph.initializer: - if n.name.startswith("model."): - replace_name_dict[n.name] = n.name[len("model.") :] - n.name = replace_name_dict[n.name] - for n in onnx_model.graph.node: - for i, name in enumerate(n.input): - if name in replace_name_dict: - n.input[i] = replace_name_dict[name] - - return onnx_model - - def _create_ort_training_session(self, optimizer_state_dict=None, session_options=None, provider_options=None): - if optimizer_state_dict is None: - optimizer_state_dict = {} - # Validating frozen_weights names - unused_frozen_weights = [ - n - for n in self.options.utils.frozen_weights - if n not in [i.name for i in self._onnx_model.graph.initializer] - ] - if unused_frozen_weights: - raise RuntimeError(f"{unused_frozen_weights} params from 'frozen_weights' not found in the ONNX model.") - - # Get loss name from model description - loss_name = [item.name for item in self.model_desc.outputs if item.is_loss] - assert len(loss_name) == 1, f"Only one loss output is supported ({len(loss_name)} were specified)" - loss_name = loss_name[0] - - # Parse optimizer parameters - optimizer_attributes_map = {} - optimizer_int_attributes_map = {} - trainable_params = set() - for initializer in self._onnx_model.graph.initializer: - if initializer.name in self.options.utils.frozen_weights: - continue # only trainable parameters are passed to the backend - trainable_params.add(initializer.name) - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - not_in_param_groups = True - for param_group in self.optim_config.params: - if initializer.name not in param_group["params"]: - continue # keep looking for a matching param_group - not_in_param_groups = False - for k, v in param_group.items(): - # 'params' is not a hyper parameter, skip it. 'lr' per weight is not supported - if k == "params" or k == "lr": - continue - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - - # set default values for params not found in groups - if not_in_param_groups: - for k, v in self.optim_config.defaults.items(): - if k == "lr": - continue - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - - self.options.distributed.horizontal_parallel_size = max(self.options.distributed.horizontal_parallel_size, 1) - self.options.distributed.data_parallel_size = ( - self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size - ) - - # TrainingParameters - ort_parameters = ort.TrainingParameters() - ort_parameters.loss_output_name = loss_name - ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled - ort_parameters.world_rank = self.options.distributed.world_rank - ort_parameters.world_size = self.options.distributed.world_size - ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps - ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation - ort_parameters.enable_adasum = self.options.distributed.enable_adasum - ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage - ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip - ort_parameters.set_gradients_as_graph_outputs = False - ort_parameters.use_memory_efficient_gradient = self.options.utils.memory_efficient_gradient - ort_parameters.training_optimizer_name = self.optim_config.name - ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name - ort_parameters.weights_to_train = trainable_params - ort_parameters.optimizer_attributes_map = optimizer_attributes_map - ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map - if bool(optimizer_state_dict): - ort_parameters.set_optimizer_initial_state(optimizer_state_dict) - - ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute - ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute - ort_parameters.transformer_layer_recompute = self.options.graph_transformer.transformer_layer_recompute - ort_parameters.number_recompute_layers = self.options.graph_transformer.number_recompute_layers - - ort_parameters.data_parallel_size = self.options.distributed.data_parallel_size - ort_parameters.horizontal_parallel_size = self.options.distributed.horizontal_parallel_size - ort_parameters.pipeline_parallel_size = self.options.distributed.pipeline_parallel.pipeline_parallel_size - ort_parameters.num_pipeline_micro_batches = ( - self.options.distributed.pipeline_parallel.num_pipeline_micro_batches - ) - ort_parameters.pipeline_cut_info_string = self.options.distributed.pipeline_parallel.pipeline_cut_info_string - # We have special handling for dictionary-typed option. - # sliced_schema._validated_opts is the original dictionary while sliced_schema is a _ORTTrainerOptionsInternal. - ort_parameters.sliced_schema = self.options.distributed.pipeline_parallel.sliced_schema._validated_opts - # We have special handling for dictionary-typed option. - # sliced_axes._validated_opts is the original dictionary while sliced_schema is a _ORTTrainerOptionsInternal. - ort_parameters.sliced_axes = self.options.distributed.pipeline_parallel.sliced_axes._validated_opts - ort_parameters.sliced_tensor_names = self.options.distributed.pipeline_parallel.sliced_tensor_names - - ort_parameters.model_after_graph_transforms_path = ( - self.options.debug.graph_save_paths.model_after_graph_transforms_path - ) - ort_parameters.model_with_gradient_graph_path = ( - self.options.debug.graph_save_paths.model_with_gradient_graph_path - ) - ort_parameters.model_with_training_graph_path = ( - self.options.debug.graph_save_paths.model_with_training_graph_path - ) - - # SessionOptions - session_options = ort.SessionOptions() if session_options is None else session_options - session_options.use_deterministic_compute = self.options.debug.deterministic_compute - if ( - self.options.graph_transformer.attn_dropout_recompute - or self.options.graph_transformer.gelu_recompute - or self.options.graph_transformer.transformer_layer_recompute - ): - session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED - if len(self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path) > 0: - session_options.optimized_model_filepath = ( - self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path - ) - - # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. - # for example, load_state_dict will be called before returing the function, and it calls _init_session again - del self._training_session - - # Set provider-specific options if needed - def get_providers(provider_options): - providers = ort.get_available_providers() - if provider_options: - for provider_name in provider_options: - if provider_name in providers: - providers[providers.index(provider_name)] = (provider_name, provider_options[provider_name]) - else: - providers.insert(0, (provider_name, provider_options[provider_name])) - # default: using cuda - elif "cuda" in self.options.device.id.lower(): - gpu_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)} - gpu_ep_name = "ROCMExecutionProvider" if self.is_rocm_pytorch else "CUDAExecutionProvider" - if self.options.device.mem_limit > 0: - gpu_ep_options["gpu_mem_limit"] = self.options.device.mem_limit - - if gpu_ep_name not in providers: - raise RuntimeError( - "ORTTrainer options specify a CUDA device but the {} provider is unavailable.".format( - cuda_ep_name # noqa: F821 - ) - ) - - providers[providers.index(gpu_ep_name)] = (gpu_ep_name, gpu_ep_options) - - return providers - - # TrainingSession - self._training_session = ort.TrainingSession( - self._onnx_model.SerializeToString(), ort_parameters, session_options, get_providers(provider_options) - ) - - # I/O bindings - self._train_io_binding = self._training_session.io_binding() - self._eval_io_binding = self._training_session.io_binding() - - def _init_onnx_model(self, inputs): - if self._onnx_model is not None: - return - - if self._torch_model is not None: - # PyTorch model is moved to cpu to save GPU memory - self._torch_model.cpu() - - # PyTorch buffers (created using 'register_buffer') shouldn't be trained - torch_buffers = list(dict(self._torch_model.named_buffers()).keys()) - self.options.utils.frozen_weights.extend(torch_buffers) - - # Export to ONNX - self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(inputs, "cpu") - - # Post processing for ONNX models expported from PyTorch - if self.options._internal_use.enable_internal_postprocess: - self._onnx_model = postprocess.run_postprocess(self._onnx_model) - if self.options._internal_use.extra_postprocess: - self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model) - - optimizer_state_dict = {} - if self._load_state_dict: - optimizer_state_dict = self._load_state_dict() - - self._init_session( - optimizer_state_dict, - session_options=self.options.session_options, - provider_options=self.options._validated_opts["provider_options"], - ) - - def _init_session(self, optimizer_state_dict={}, session_options=None, provider_options=None): # noqa: B006 - if self._onnx_model is None: - return - - if self.options.utils.run_symbolic_shape_infer: - self._onnx_model = SymbolicShapeInference.infer_shapes( - self._onnx_model, auto_merge=True, guess_output_rank=True - ) - - # Create training session used by train_step - # pass all optimizer states to the backend - self._create_ort_training_session( - optimizer_state_dict, session_options=session_options, provider_options=provider_options - ) - - # Update model description to update dtype when mixed precision is enabled - # C++ backend modifies model's output dtype from float32 to float16 for mixed precision - # Note that for training we must use float32 and for evaluation we must use float16 - for idx, o_desc in enumerate(self.model_desc.outputs): - if ( - self.options.mixed_precision.enabled - and o_desc.dtype == torch.float32 - and not self._training_session.is_output_fp32_node(o_desc.name) - ): - self.model_desc.add_type_to_output_description(idx, o_desc.dtype, torch.float16) - - # Update model description - self._model_desc_inputs_with_lr = [*self.model_desc.inputs, self.model_desc.learning_rate] - - # Update Mixed Precision, if applicable - if self.options.mixed_precision.enabled: - self.model_desc.loss_scale_input = self._training_session.loss_scale_input_name - self._model_desc_inputs_with_lr_and_loss_scale = [ - *self._model_desc_inputs_with_lr, - self.model_desc.loss_scale_input, - ] - self.model_desc.all_finite = _utils.get_all_gradients_finite_name_from_session(self._training_session) - self._model_desc_outputs_with_all_finite = [*self.model_desc.outputs, self.model_desc.all_finite] - elif self.options.mixed_precision.loss_scaler: - raise ValueError("Loss Scaler cannot be specified when Mixed Precision is not enabled") - - # Update Loss Scaler Input Name, if applicable - if self.options.mixed_precision.enabled and self.options.mixed_precision.loss_scaler: - self.options.mixed_precision.loss_scaler.input_name = self.model_desc.loss_scale_input.name - elif not self.options.mixed_precision.enabled and self.options.mixed_precision.loss_scaler: - raise ValueError("Loss Scaler cannot be specified when Mixed Precision is not enabled") - - # Update Gradient Accumulation, if applicable - if self.options.batch.gradient_accumulation_steps > 1: - self.model_desc.gradient_accumulation = _utils.get_gradient_accumulation_name_from_session( - self._training_session - ) - self._model_desc_outputs_with_gradient_accumulation = [ - *self.model_desc.outputs, - self.model_desc.gradient_accumulation, - ] - - # TODO: Remove when experimental checkpoint functions are removed - if self._state_dict: - checkpoint.experimental_load_state_dict(self, self._state_dict, self._load_state_dict_strict) - self._state_dict_debug = self._state_dict - self._state_dict = {} - - def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs): - # Normalize input to tuple of samples - if type(inputs) == tuple and len(inputs) == 1 and type(inputs[0]) == list: # noqa: E721 - input = tuple(inputs[0]) - else: - input = inputs - - # Append input from 'kwargs' - for input_desc in inputs_desc: - if input_desc.name in kwargs: - input = (*input, kwargs[input_desc.name]) - - # Append learning rate - extra_inputs = 0 - if lr is not None: - lr = torch.tensor([lr]) - input += (lr,) - extra_inputs += 1 - - # Append loss scale - if loss_scale is not None: - assert self.options.mixed_precision.enabled, "Loss scale cannot be used without mixed precision" - loss_scale = torch.tensor([loss_scale]) - input += (loss_scale,) - extra_inputs += 1 - - # Only assert length of input when fetches is not used - assert self._train_step_info.fetches or len(self.model_desc.inputs) + extra_inputs == len(input) - return input - - def _resolve_symbolic_dimensions(self, inputs, inputs_desc, outputs_desc): - outputs = copy.deepcopy(outputs_desc) - resolved_dims = {} - for input, i_desc in zip(inputs, inputs_desc): - for i_idx, i_axis in enumerate(i_desc.shape): - if isinstance(i_axis, str): - if i_axis not in resolved_dims: - resolved_dims[i_axis] = input.size()[i_idx] - else: - assert resolved_dims[i_axis] == input.size()[i_idx], f"Mismatch in dynamic shape {i_axis}" - - for o_desc in outputs: - for idx_o, o_axis in enumerate(o_desc.shape): - if isinstance(o_axis, str): - o_desc.shape[idx_o] = resolved_dims[o_axis] - - unknown_dim = [o_desc.name for dim in o_desc.shape for o_desc in outputs if isinstance(dim, str)] - if unknown_dim: - raise RuntimeError(f"Cannot execute model with unknown output dimensions ({unknown_dim}") - - return outputs - - def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_desc, run_options=None): - # Select IO binding - if is_train: - iobinding = self._train_io_binding - else: - iobinding = self._eval_io_binding - - # Get the list of the actual session inputs because unused inputs can be removed. - input_nodes = self._training_session.get_inputs() - input_node_names = [input_node.name for input_node in input_nodes] - - # Bind input tensors - for input, input_desc in zip(inputs, inputs_desc): - if input_desc.name in input_node_names: - device_index = _utils.get_device_index_from_input(input) - iobinding.bind_input( - input_desc.name, - input.device.type, - device_index, - _utils.dtype_torch_to_numpy(input.dtype), - list(input.size()), - input.data_ptr(), - ) - - # Bind output tensors - outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc) - result = {} - for output_desc in outputs_desc_resolved: - target_device = self.options.device.id - if self.options.mixed_precision.enabled and output_desc.name == self.model_desc.all_finite.name: - # Keep all finite flag on CPU to match backend implementation - # This prevents CPU -> GPU -> CPU copies between frontend and backend - target_device = "cpu" - # the self.options.device may be a device that pytorch does not recognize. - # in that case, we temporary prefer to leave the input/output on CPU and let ORT session - # to move the data between device and host. - # so output will be on the same device as input. - try: - torch.device(target_device) - except Exception: - # in this case, input/output must on CPU - assert input.device.type == "cpu" - target_device = "cpu" - - torch_tensor = torch.zeros( - output_desc.shape, - device=target_device, - dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype, - ) - iobinding.bind_output( - output_desc.name, - torch_tensor.device.type, - _utils.get_device_index(target_device), - _utils.dtype_torch_to_numpy(torch_tensor.dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - result[output_desc.name] = torch_tensor - - # Run a train/eval step - self._training_session.run_with_iobinding(iobinding, run_options) - return result - - def _update_onnx_model_initializers(self, state_tensors): - r"""Updates ONNX graph initializers with state_tensors's values - - Usually called to save or load an ONNX model. - - The tensors names of state_tensors are compared to all ONNX initializer tensors - and when the name matches, the ONNX graph is updated with the new value. - """ - assert isinstance(state_tensors, dict), "state_tensors must be a dict" - - new_weights = [] - replace_indices = [] - for i, w in enumerate(self._onnx_model.graph.initializer): - if w.name in state_tensors: - new_weights.append(onnx.numpy_helper.from_array(state_tensors[w.name], w.name)) - replace_indices.append(i) - replace_indices.sort(reverse=True) - for w_i in replace_indices: - del self._onnx_model.graph.initializer[w_i] - self._onnx_model.graph.initializer.extend(new_weights) - - def _extract_model_states(self, state_dict, pytorch_format): - """Extract model states from the training session and load into the state_dict""" - - model_states = self._training_session.get_model_state(include_mixed_precision_weights=False) - state_dict[_utils.state_dict_model_key()] = {} - - # extract trained model weights from the training session - for precision in model_states: - state_dict[_utils.state_dict_model_key()][precision] = {} - for model_state_key in model_states[precision]: - if pytorch_format: - state_dict[_utils.state_dict_model_key()][precision][model_state_key] = torch.from_numpy( - model_states[precision][model_state_key] - ) - else: - state_dict[_utils.state_dict_model_key()][precision][model_state_key] = model_states[precision][ - model_state_key - ] - - # extract untrained (frozen) model weights - for node in self._onnx_model.graph.initializer: - if ( - node.name not in state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] - and node.name in self.options.utils.frozen_weights - ): - if pytorch_format: - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][ - node.name - ] = torch.from_numpy(onnx.numpy_helper.to_array(node)) - else: - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][ - node.name - ] = onnx.numpy_helper.to_array(node) - - def _extract_trainer_options(self, state_dict): - """Extract relevant trainer configuration and load it into the state_dict""" - - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - - state_dict[_utils.state_dict_trainer_options_key()] = {} - state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = self.options.mixed_precision.enabled - state_dict[_utils.state_dict_trainer_options_key()][ - zero_stage - ] = self.options.distributed.deepspeed_zero_optimization.stage - state_dict[_utils.state_dict_trainer_options_key()][world_rank] = self.options.distributed.world_rank - state_dict[_utils.state_dict_trainer_options_key()][world_size] = self.options.distributed.world_size - state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = self.optim_config.name - state_dict[_utils.state_dict_trainer_options_key()][D_size] = self.options.distributed.data_parallel_size - state_dict[_utils.state_dict_trainer_options_key()][H_size] = self.options.distributed.horizontal_parallel_size - - def _extract_train_step_info(self, state_dict): - """Extract train step info settings and save it into the state_dict""" - - optimization_step = _utils.state_dict_train_step_info_optimization_step_key() - step = _utils.state_dict_train_step_info_step_key() - - state_dict[_utils.state_dict_train_step_info_key()] = {} - state_dict[_utils.state_dict_train_step_info_key()][optimization_step] = self._train_step_info.optimization_step - state_dict[_utils.state_dict_train_step_info_key()][step] = self._train_step_info.step - - def state_dict(self, pytorch_format=False): - """Returns a dictionary with model, train step info and optionally, optimizer states - - The returned dictionary contains the following information: - - Model and optimizer states - - Required ORTTrainerOptions settings - - Distributed training information, such as but not limited to ZeRO - - Train step info settings - - Structure of the returned dictionary: - - When `pytorch_format = False` - schema: - { - "model": - { - type: dict, - schema: - { - "full_precision": - { - type: dict, - schema: - { - model_weight_name: - { - type: array - } - } - } - } - }, - "optimizer": - { - type: dict, - schema: - { - model_weight_name: - { - type: dict, - schema: - { - "Moment_1": - { - type: array - }, - "Moment_2": - { - type: array - }, - "Update_Count": - { - type: array, - optional: True # present if optimizer is adam, absent otherwise - } - } - }, - "shared_optimizer_state": - { - type: dict, - optional: True, # present optimizer is shared, absent otherwise. - schema: - { - "step": - { - type: array, - } - } - } - } - }, - "trainer_options": - { - type: dict, - schema: - { - "mixed_precision": - { - type: bool - }, - "zero_stage": - { - type: int - }, - "world_rank": - { - type: int - }, - "world_size": - { - type: int - }, - "optimizer_name": - { - type: str - }, - "data_parallel_size": - { - type: int - }, - "horizontal_parallel_size": - { - type: int - } - } - }, - "partition_info": - { - type: dict, - optional: True, # present if states partitioned, else absent - schema: - { - model_weight_name: - { - type: dict, - schema: - { - "original_dim": - { - type: array - }, - "megatron_row_partition": - { - type: int - } - } - } - } - }, - "train_step_info": - { - type: dict, - schema: - { - "optimization_step": - { - type: int - }, - "step": - { - type: int - } - } - } - } - - When `pytorch_format = True` - schema: - { - model_weight_name: - { - type: tensor - } - } - - Args: - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema - - Returns: - A dictionary with `ORTTrainer` state - """ - if not self._training_session: - warnings.warn( - "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict().", - UserWarning, - ) - return self._load_state_dict.args[0] if self._load_state_dict else {} - - state_dict = {} - - # load training session model states into the state_dict - self._extract_model_states(state_dict, pytorch_format) - if pytorch_format: - if self.options.distributed.deepspeed_zero_optimization.stage > 0: - warnings.warn("Incomplete state_dict: ZeRO enabled", UserWarning) - if self.options.distributed.horizontal_parallel_size > 1: - warnings.warn("Incomplete state_dict: Megatron enabled", UserWarning) - # if pytorch_format is true, return a flat dictionary with only model states - # which is compatible with a PyTorch model - return state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] - - # load training session optimizer states into the state_dict - state_dict[_utils.state_dict_optimizer_key()] = self._training_session.get_optimizer_state() - - # extract the relevant training configuration from the trainer and load them into the state_dict - self._extract_trainer_options(state_dict) - - # Extract train step info settings and load it into the state_dict - self._extract_train_step_info(state_dict) - - # add partition information in case of a distributed run - if ( - self.options.distributed.deepspeed_zero_optimization.stage > 0 - or self.options.distributed.horizontal_parallel_size > 1 - ): - state_dict[_utils.state_dict_partition_info_key()] = self._training_session.get_partition_info_map() - - return state_dict - - def _load_model_states(self, state_dict, strict): - """Load the model states onto the onnx model graph""" - - if _utils.state_dict_model_key() not in state_dict: - return - - # collect all initializer names from the current onnx graph - assert self._onnx_model, "ONNX model graph is not exported" - initializer_names = {node.name for node in self._onnx_model.graph.initializer} - - # loaded_initializers dict will be loaded with all the model states from the state dictionary - # that are found in the initializer_names dictionary - loaded_initializers = {} - - # copy over model states from the input state dict onto the onnx model - for precision, precision_states in state_dict[_utils.state_dict_model_key()].items(): - for state_key, state_value in precision_states.items(): - if state_key in initializer_names: - loaded_initializers[state_key] = state_value - elif strict: - raise RuntimeError(f"Unexpected key: {state_key} in state_dict[model][{precision}]") - - # update onnx model from loaded initializers - self._update_onnx_model_initializers(loaded_initializers) - - def _load_optimizer_states(self, current_state_dict, state_dict): - """Load the optimizer states onto the training session state dictionary""" - - def _check_optimizer_mismatch(state_dict): - """Assert that the loaded optimizer has the same config as the current training session config""" - - # the state_dict optimizer_name can be a byte string (if coming from checkpoint file) - # or can be a regular string (coming from user) - optimizer_name = state_dict[_utils.state_dict_trainer_options_key()][ - _utils.state_dict_trainer_options_optimizer_name_key() - ] - - # optimizer_name can be either a regular string or a byte string. - # if it is a byte string, convert to regular string using decode() - # if it is a regular string, do nothing to it - try: # noqa: SIM105 - optimizer_name = optimizer_name.decode() - except AttributeError: - pass - assert self.optim_config.name == optimizer_name, "Optimizer mismatch: expected {}, got {}".format( - self.optim_config.name, optimizer_name - ) - - if _utils.state_dict_optimizer_key() not in state_dict: - return - - # check optimizer config names are the same for current session and the sessino being loaded - _check_optimizer_mismatch(state_dict) - - # create an entry for the optimizer in the training session state dictionary - if _utils.state_dict_optimizer_key() not in current_state_dict: - current_state_dict[_utils.state_dict_optimizer_key()] = {} - - # copy over optimizer states from the input state dict onto the training session state dict - for model_state_key, optimizer_dict in state_dict[_utils.state_dict_optimizer_key()].items(): - if model_state_key not in current_state_dict[_utils.state_dict_optimizer_key()]: - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key] = {} - for optimizer_state_key, optimizer_state_value in optimizer_dict.items(): - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key][ - optimizer_state_key - ] = optimizer_state_value - - def _load_state_dict_impl(self, state_dict, strict=True): - """Load the state dictionary onto the onnx model and on the training session graph""" - - # clear the callable partial - self._load_state_dict = None - - def _mismatch_keys(keys1, keys2, in_error_str, allow_unexpected=False): - """Find out the missing and the unexpected keys in two dictionaries - - Throws a runtime error if missing or unexpected keys are found - - Keys in keys1 not in keys2 will be marked as missing - - Keys in keys2 not in keys1 will be marked as unexpected - """ - keys1 = set(keys1) - keys2 = set(keys2) - missing_keys = list(keys1 - keys2) - unexpected_keys = list(keys2 - keys1) - if len(missing_keys) > 0: - raise RuntimeError(f"Missing keys: {missing_keys} in {in_error_str}") - if len(unexpected_keys) > 0 and not allow_unexpected: - raise RuntimeError(f"Unexpected keys: {unexpected_keys} in {in_error_str}") - - def _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is any mismatch in the model sub state dictionary between the two state_dicts""" - - # check unxexpected and missing precision keys in the model state_dict compared to the training - # session model state_dict - _mismatch_keys( - current_state_dict[_utils.state_dict_model_key()], - state_dict[_utils.state_dict_model_key()], - "state_dict[model]", - allow_unexpected, - ) - - # check for model state key mismatch - for precision_key in current_state_dict[_utils.state_dict_model_key()]: - _mismatch_keys( - current_state_dict[_utils.state_dict_model_key()][precision_key], - state_dict[_utils.state_dict_model_key()][precision_key], - f"state_dict[model][{precision_key}]", - allow_unexpected, - ) - - def _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is any mismatch in the optimizer sub state dictionary between the two state_dicts""" - - # check for model state key mismatch for the optimizer state_dict - _mismatch_keys( - current_state_dict[_utils.state_dict_optimizer_key()], - state_dict[_utils.state_dict_optimizer_key()], - "state_dict[optimizer]", - allow_unexpected, - ) - - # check for optimizer state keys mismatch - for model_state_key in current_state_dict[_utils.state_dict_optimizer_key()]: - _mismatch_keys( - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key], - state_dict[_utils.state_dict_optimizer_key()][model_state_key], - f"state_dict[optimizer][{model_state_key}]", - allow_unexpected, - ) - - def _check_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is a mismatch in the keys (model and optimizer) in the two state_dicts""" - - # check presence of 'model' in the input state_dict - if _utils.state_dict_model_key() in state_dict: - _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected) - else: - warnings.warn("Missing key: model in state_dict", UserWarning) - # check presence of 'optimizer' in the input state_dict - if _utils.state_dict_optimizer_key() in state_dict: - _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected) - else: - warnings.warn("Missing key: optimizer in state_dict", UserWarning) - - # extract state dict from the current training session. this is to persist the states between - # two training sessions. - # for example, if user provided only the model states, the optimizer states from the current - # training session must be persisted - current_state_dict = {} - if self._training_session: - current_state_dict = self.state_dict() - if strict: - # for Zero enabled, the current trainer might not have the complete state, and we must allow - # extra keys to be present in the state dict - allow_unexpected = self.options.distributed.deepspeed_zero_optimization.stage > 0 - _check_key_mismatch(current_state_dict, state_dict, allow_unexpected) - - # load the model states from the input state dictionary into the onnx graph - self._load_model_states(state_dict, strict) - - # load the optimizer states from the input state dictionary into the training session states - # dictionary - self._load_optimizer_states(current_state_dict, state_dict) - - return ( - current_state_dict[_utils.state_dict_optimizer_key()] - if _utils.state_dict_optimizer_key() in current_state_dict - else {} - ) - - def _load_train_step_info(self, state_dict): - """Load the train step info settings from state dict""" - - if _utils.state_dict_train_step_info_key() not in state_dict: - warnings.warn("Missing key: train_step_info in state_dict", UserWarning) - return - - optimization_step = _utils.state_dict_train_step_info_optimization_step_key() - step = _utils.state_dict_train_step_info_step_key() - - self._train_step_info.optimization_step = state_dict[_utils.state_dict_train_step_info_key()][optimization_step] - self._train_step_info.step = state_dict[_utils.state_dict_train_step_info_key()][step] - - def load_state_dict(self, state_dict, strict=True): - """Loads state_dict containing model/optimizer states into ORTTrainer - - The state_dict dictionary may contain the following information: - - Model and optimizer states - - Required ORTTrainerOptions settings - - Distributed training information, such as but not limited to ZeRO - - Args: - state_dict: state dictionary containing both model and optimizer states. The structure of this dictionary - should be the same as the one that is returned by ORTTrainer.state_dict for the case when pytorch_format=False - strict: boolean flag to strictly enforce that the input state_dict keys match the keys from ORTTrainer.state_dict - """ - - # if onnx graph has not been initialized, loading of states will be put on hold. - # a copy of the state_dict and other arguments to the function will be stored until the onnx graph has - # been initialized. Once the graph is initialized, the desired states will be loaded onto the grpah - if not self._training_session: - self._load_state_dict = partial(self._load_state_dict_impl, state_dict, strict=strict) - return - - # load the train step info settings - self._load_train_step_info(state_dict) - - # load states onto the frontend onnx graph - optimizer_state_dict = self._load_state_dict_impl(state_dict, strict=strict) - - # create a new training session after loading initializer states onto the onnx graph - # pass the populated states to the training session to populate the backend graph - self._init_session( - optimizer_state_dict, - session_options=self.options.session_options, - provider_options=self.options._validated_opts["provider_options"], - ) - - def save_checkpoint(self, path, user_dict={}, include_optimizer_states=True): # noqa: B006 - """Persists ORTTrainer state dictionary on disk along with user_dict. - - Saves the state_dict along with the user_dict to a file specified by path. - - Args: - path: string representation to a file path or a python file-like object. - if file already exists at path, an exception is raised. - user_dict: custom data to be saved along with the state_dict. This data will be returned - to the user when load_checkpoint is called. - include_optimizer_states: boolean flag indicating whether or not to persist the optimizer states. - on load_checkpoint, only model states will be loaded if include_optimizer_states==True - """ - - # extract state_dict to be saved in the checkpoint - state_dict = self.state_dict() - - # if user_dict is provided, serialize to bytes and convert to hex string. - # this helps in loading the types as they are given by the user since hdf5 - # converts to numpy types otherwise - if bool(user_dict): - state_dict[_utils.state_dict_user_dict_key()] = _checkpoint_storage.to_serialized_hex(user_dict) - - # if include_optimizer_states is False, only save the model states in the checkpoint file - if not include_optimizer_states: - if _utils.state_dict_optimizer_key() in state_dict: - del state_dict[_utils.state_dict_optimizer_key()] - - _checkpoint_storage.save(state_dict, path) - - def _aggregation_required(self, loaded_trainer_options): - """Checks if aggregation is required for the loading the state_dict into the ORTTrainer""" - - # To load states in the backend, aggregation is required for every ZeRO - # or Megatron checkpoint - return ( - loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 - or loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1 - ) - - def load_checkpoint(self, *paths, strict=True): - """Loads the saved checkpoint state dictionary into the ORTTrainer - - Reads the saved checkpoint files specified by paths from disk and loads the state dictionary - onto the ORTTrainer. - Aggregates the checkpoint files if aggregation is required. - - Args: - paths: one or more files represented as strings where the checkpoint is saved - strict: boolean flag to strictly enforce that the saved checkpoint state_dict - keys match the keys from ORTTrainer.state_dict - Returns: - dictionary that the user had saved when calling save_checkpoint - """ - state_dict = {} - - # check if aggregation is required - loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key()) - if self._aggregation_required(loaded_trainer_options): - # if aggregation is required, aggregation logic must be run on the saved checkpoints - state_dict = checkpoint.aggregate_checkpoints(paths, pytorch_format=False) - else: - # if aggregation is not required, there must only be a single file that needs to be loaded - assert len(paths) == 1, f"Expected number of files to load: 1, got {len(paths)}" - state_dict = _checkpoint_storage.load(paths[0]) - - # extract user dict from the saved checkpoint - user_dict = {} - if _utils.state_dict_user_dict_key() in state_dict: - user_dict = _checkpoint_storage.from_serialized_hex(state_dict[_utils.state_dict_user_dict_key()]) - del state_dict[_utils.state_dict_user_dict_key()] - - self.load_state_dict(state_dict, strict=strict) - - return user_dict diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py deleted file mode 100644 index c63ac6f82c87f..0000000000000 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ /dev/null @@ -1,692 +0,0 @@ -import cerberus - -import onnxruntime as ort -from onnxruntime.capi._pybind_state import PropagateCastOpsStrategy - -from .amp import loss_scaler -from .optim import lr_scheduler - - -class ORTTrainerOptions: - r"""Settings used by ONNX Runtime training backend - - The parameters are hierarchically organized to facilitate configuration through semantic groups - that encompasses features, such as distributed training, etc. - - Input validation is performed on the input dict during instantiation to ensure - that supported parameters and values are passed in. Invalid input results - in :py:obj:`ValueError` exception with details on it. - - Args: - options (dict): contains all training options - _validate (bool, default is True): for internal use only - - Supported schema for kwargs: - - .. code-block:: python - - schema = { - 'batch' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'gradient_accumulation_steps' : { - 'type' : 'integer', - 'min' : 1, - 'default' : 1 - } - }, - }, - 'device' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'id' : { - 'type' : 'string', - 'default' : 'cuda' - }, - 'mem_limit' : { - 'type' : 'integer', - 'min' : 0, - 'default' : 0 - } - } - }, - 'distributed': { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'world_rank': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'world_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'local_rank': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'data_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'horizontal_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'pipeline_parallel' : { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'pipeline_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'num_pipeline_micro_batches': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'pipeline_cut_info_string': { - 'type': 'string', - 'default': '' - }, - 'sliced_schema': { - 'type': 'dict', - 'default': {}, - 'keysrules': {'type': 'string'}, - 'valuesrules': { - 'type': 'list', - 'schema': {'type': 'integer'} - } - }, - 'sliced_axes': { - 'type': 'dict', - 'default': {}, - 'keysrules': {'type': 'string'}, - 'valuesrules': {'type': 'integer'} - }, - 'sliced_tensor_names': { - 'type': 'list', - 'schema': {'type': 'string'}, - 'default': [] - } - } - }, - 'allreduce_post_accumulation': { - 'type': 'boolean', - 'default': False - }, - 'deepspeed_zero_optimization': { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'stage': { - 'type': 'integer', - 'min': 0, - 'max': 1, - 'default': 0 - }, - } - }, - 'enable_adasum': { - 'type': 'boolean', - 'default': False - } - } - }, - 'lr_scheduler' : { - 'type' : 'optim.lr_scheduler', - 'nullable' : True, - 'default' : None - }, - 'mixed_precision' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'enabled' : { - 'type' : 'boolean', - 'default' : False - }, - 'loss_scaler' : { - 'type' : 'amp.loss_scaler', - 'nullable' : True, - 'default' : None - } - } - }, - 'graph_transformer': { - 'type': 'dict', - 'required': False, - 'default': {}, - 'schema': { - 'attn_dropout_recompute': { - 'type': 'boolean', - 'default': False - }, - 'gelu_recompute': { - 'type': 'boolean', - 'default': False - }, - 'transformer_layer_recompute': { - 'type': 'boolean', - 'default': False - }, - 'number_recompute_layers': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'propagate_cast_ops_config': { - 'type': 'dict', - 'required': False, - 'default': {}, - 'schema': { - 'propagate_cast_ops_strategy': { - 'type': 'onnxruntime.training.PropagateCastOpsStrategy', - 'default': PropagateCastOpsStrategy.FLOOD_FILL - }, - 'propagate_cast_ops_level': { - 'type': 'integer', - 'default': 1 - }, - 'propagate_cast_ops_allow': { - 'type': 'list', - 'schema': {'type': 'string'}, - 'default': [] - } - } - } - } - }, - 'utils' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'frozen_weights' : { - 'type' : 'list', - 'default' : [] - }, - 'grad_norm_clip' : { - 'type' : 'boolean', - 'default' : True - }, - 'memory_efficient_gradient' : { - 'type' : 'boolean', - 'default' : False - }, - 'run_symbolic_shape_infer' : { - 'type' : 'boolean', - 'default' : False - } - } - }, - 'debug' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'deterministic_compute' : { - 'type' : 'boolean', - 'default' : False - }, - 'check_model_export' : { - 'type' : 'boolean', - 'default' : False - }, - 'graph_save_paths' : { - 'type' : 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'model_after_graph_transforms_path': { - 'type': 'string', - 'default': '' - }, - 'model_with_gradient_graph_path':{ - 'type': 'string', - 'default': '' - }, - 'model_with_training_graph_path': { - 'type': 'string', - 'default': '' - }, - 'model_with_training_graph_after_optimization_path': { - 'type': 'string', - 'default': '' - }, - } - }, - } - }, - '_internal_use' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'enable_internal_postprocess' : { - 'type' : 'boolean', - 'default' : True - }, - 'extra_postprocess' : { - 'type' : 'callable', - 'nullable' : True, - 'default' : None - }, - 'onnx_opset_version': { - 'type': 'integer', - 'min' : 12, - 'max' :14, - 'default': 14 - }, - 'enable_onnx_contrib_ops' : { - 'type' : 'boolean', - 'default' : True - } - } - }, - 'provider_options':{ - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': {} - }, - 'session_options': { - 'type': 'SessionOptions', - 'nullable': True, - 'default': None - }, - } - - Keyword arguments: - batch (dict): - batch related settings - batch.gradient_accumulation_steps (int, default is 1): - number of steps to accumulate before do collective gradient reduction - device (dict): - compute device related settings - device.id (string, default is 'cuda'): - device to run training - device.mem_limit (int): - maximum memory size (in bytes) used by device.id - distributed (dict): - distributed training options. - distributed.world_rank (int, default is 0): - rank ID used for data/horizontal parallelism - distributed.world_size (int, default is 1): - number of ranks participating in parallelism - distributed.data_parallel_size (int, default is 1): - number of ranks participating in data parallelism - distributed.horizontal_parallel_size (int, default is 1): - number of ranks participating in horizontal parallelism - distributed.pipeline_parallel (dict): - Options which are only useful to pipeline parallel. - distributed.pipeline_parallel.pipeline_parallel_size (int, default is 1): - number of ranks participating in pipeline parallelism - distributed.pipeline_parallel.num_pipeline_micro_batches (int, default is 1): - number of micro-batches. We divide input batch into micro-batches and run the graph. - distributed.pipeline_parallel.pipeline_cut_info_string (string, default is ''): - string of cutting ids for pipeline partition. - distributed.allreduce_post_accumulation (bool, default is False): - True enables overlap of AllReduce with computation, while False, - postpone AllReduce until all gradients are ready - distributed.deepspeed_zero_optimization: - DeepSpeed ZeRO options. - distributed.deepspeed_zero_optimization.stage (int, default is 0): - select which stage of DeepSpeed ZeRO to use. Stage 0 means disabled. - distributed.enable_adasum (bool, default is False): - enable `Adasum `_ - algorithm for AllReduce - lr_scheduler (optim._LRScheduler, default is None): - specifies learning rate scheduler - mixed_precision (dict): - mixed precision training options - mixed_precision.enabled (bool, default is False): - enable mixed precision (fp16) - mixed_precision.loss_scaler (amp.LossScaler, default is None): - specifies a loss scaler to be used for fp16. If not specified, - :py:class:`.DynamicLossScaler` is used with default values. - Users can also instantiate :py:class:`.DynamicLossScaler` and - override its parameters. Lastly, a completely new implementation - can be specified by extending :py:class:`.LossScaler` class from scratch - graph_transformer (dict): - graph transformer related configurations - graph_transformer.attn_dropout_recompute(bool, default False) - graph_transformer.gelu_recompute(bool, default False) - graph_transformer.transformer_layer_recompute(bool, default False) - graph_transformer.number_recompute_layers(bool, default False) - graph_transformer.propagate_cast_ops_config (dict): - graph_transformer.propagate_cast_ops_config.strategy(PropagateCastOpsStrategy, default FLOOD_FILL) - Specify the choice of the cast propagation optimization strategy, either, NONE, INSERT_AND_REDUCE or FLOOD_FILL. - NONE strategy does not perform any cast propagation transformation on the graph, although other optimizations - locally change cast operations, for example, in order to fuse Transpose and MatMul nodes, the TransposeMatMulFunsion optimization could - interchange Transpose and Cast if the Cast node exists between Transpose and MatMul. - INSERT_AND_REDUCE strategy inserts and reduces cast operations around the nodes with allowed opcodes. - FLOOD_FILL strategy expands float16 regions in the graph using the allowed opcodes, and unlike - INSERT_AND_REDUCE does not touch opcodes outside expanded float16 region. - graph_transformer.propagate_cast_ops_config.level(integer, default 1) - Optimize by moving Cast operations if propagate_cast_ops_level is non-negative. - Use predetermined list of opcodes considered safe to move before/after cast operation - if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise. - graph_transformer.propagate_cast_ops_config.allow(list of str, []) - List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero. - attn_dropout_recompute (bool, default is False): - enable recomputing attention dropout to save memory - gelu_recompute (bool, default is False): - enable recomputing Gelu activation output to save memory - transformer_layer_recompute (bool, default is False): - enable recomputing transformer layerwise to save memory - number_recompute_layers (int, default is 0) - number of layers to apply transformer_layer_recompute, by default system will - apply recompute to all the layers, except for the last one - utils (dict): - miscellaneous options - utils.frozen_weights (list of str, []): - list of model parameter names to skip training (weights don't change) - utils.grad_norm_clip (bool, default is True): - enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer' - utils.memory_efficient_gradient (bool, default is False): - enables use of memory aware gradient builder. - utils.run_symbolic_shape_infer (bool, default is False): - runs symbolic shape inference on the model - debug (dict): - debug options - debug.deterministic_compute (bool, default is False) - forces compute to be deterministic accross runs - debug.check_model_export (bool, default is False) - compares PyTorch model outputs with ONNX model outputs in inference before the first - train step to ensure successful model export - debug.graph_save_paths (dict): - paths used for dumping ONNX graphs for debugging purposes - debug.graph_save_paths.model_after_graph_transforms_path (str, default is "") - path to export the ONNX graph after training-related graph transforms have been applied. - No output when it is empty. - debug.graph_save_paths.model_with_gradient_graph_path (str, default is "") - path to export the ONNX graph with the gradient graph added. No output when it is empty. - debug.graph_save_paths.model_with_training_graph_path (str, default is "") - path to export the training ONNX graph with forward, gradient and optimizer nodes. - No output when it is empty. - debug.graph_save_paths.model_with_training_graph_after_optimization_path (str, default is "") - outputs the optimized training graph to the path if nonempty. - _internal_use (dict): - internal options, possibly undocumented, that might be removed without notice - _internal_use.enable_internal_postprocess (bool, default is True): - enable internal internal post processing of the ONNX model - _internal_use.extra_postprocess (callable, default is None) - a functor to postprocess the ONNX model and return a new ONNX model. - It does not override :py:attr:`._internal_use.enable_internal_postprocess`, but complement it - _internal_use.onnx_opset_version (int, default is 14): - ONNX opset version used during model exporting. - _internal_use.enable_onnx_contrib_ops (bool, default is True) - enable PyTorch to export nodes as contrib ops in ONNX. - This flag may be removed anytime in the future. - session_options (onnxruntime.SessionOptions): - The SessionOptions instance that TrainingSession will use. - provider_options (dict): - The provider_options for customized execution providers. it is dict map from EP name to - a key-value pairs, like {'EP1' : {'key1' : 'val1'}, ....} - - Example: - .. code-block:: python - - opts = ORTTrainerOptions({ - 'batch' : { - 'gradient_accumulation_steps' : 128 - }, - 'device' : { - 'id' : 'cuda:0', - 'mem_limit' : 2*1024*1024*1024, - }, - 'lr_scheduler' : optim.lr_scheduler.LinearWarmupLRScheduler(), - 'mixed_precision' : { - 'enabled': True, - 'loss_scaler': amp.LossScaler(loss_scale=float(1 << 16)) - } - }) - fp16_enabled = opts.mixed_precision.enabled - """ - - def __init__(self, options={}): # noqa: B006 - # Keep a copy of original input for debug - self._original_opts = dict(options) - - # Used for logging purposes - self._main_class_name = self.__class__.__name__ - - # Validates user input - self._validated_opts = dict(self._original_opts) - validator = ORTTrainerOptionsValidator(_ORTTRAINER_OPTIONS_SCHEMA) - self._validated_opts = validator.validated(self._validated_opts) - if self._validated_opts is None: - raise ValueError(f"Invalid options: {validator.errors}") - - # Convert dict in object - for k, v in self._validated_opts.items(): - setattr(self, k, self._wrap(v)) - - def __repr__(self): - return "{%s}" % str( - ", ".join( - f"'{k}': {v!r}" - for (k, v) in self.__dict__.items() - if k not in ["_original_opts", "_validated_opts", "_main_class_name"] - ) - ) - - def _wrap(self, v): - if isinstance(v, (tuple, list, set, frozenset)): - return type(v)([self._wrap(i) for i in v]) - else: - return _ORTTrainerOptionsInternal(self._main_class_name, v) if isinstance(v, dict) else v - - -class _ORTTrainerOptionsInternal(ORTTrainerOptions): - r"""Internal class used by ONNX Runtime training backend for input validation - - NOTE: Users MUST NOT use this class in any way! - """ - - def __init__(self, main_class_name, options): - # Used for logging purposes - self._main_class_name = main_class_name - # We don't call super().__init__(options) here but still called it "_validated_opts" - # instead of "_original_opts" because it has been validated in the top-level - # ORTTrainerOptions's constructor. - self._validated_opts = dict(options) - # Convert dict in object - for k, v in dict(options).items(): - setattr(self, k, self._wrap(v)) - - -class ORTTrainerOptionsValidator(cerberus.Validator): - _LR_SCHEDULER = cerberus.TypeDefinition("lr_scheduler", (lr_scheduler._LRScheduler,), ()) - _LOSS_SCALER = cerberus.TypeDefinition("loss_scaler", (loss_scaler.LossScaler,), ()) - - _SESSION_OPTIONS = cerberus.TypeDefinition("session_options", (ort.SessionOptions,), ()) - - _PROPAGATE_CAST_OPS_STRATEGY = cerberus.TypeDefinition( - "propagate_cast_ops_strategy", (PropagateCastOpsStrategy,), () - ) - - types_mapping = cerberus.Validator.types_mapping.copy() - types_mapping["lr_scheduler"] = _LR_SCHEDULER - types_mapping["loss_scaler"] = _LOSS_SCALER - types_mapping["session_options"] = _SESSION_OPTIONS - types_mapping["propagate_cast_ops_strategy"] = _PROPAGATE_CAST_OPS_STRATEGY - - -def _check_is_callable(field, value, error): - result = False - try: - # Python 3 - result = value is None or callable(value) - except Exception: - # Python 3 but < 3.2 - if hasattr(value, "__call__"): # noqa: B004 - result = True - if not result: - error(field, "Must be callable or None") - - -_ORTTRAINER_OPTIONS_SCHEMA = { - "batch": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": {"gradient_accumulation_steps": {"type": "integer", "min": 1, "default": 1}}, - }, - "device": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "id": {"type": "string", "default": "cuda"}, - "mem_limit": {"type": "integer", "min": 0, "default": 0}, - }, - }, - "distributed": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "world_rank": {"type": "integer", "min": 0, "default": 0}, - "world_size": {"type": "integer", "min": 1, "default": 1}, - "local_rank": {"type": "integer", "min": 0, "default": 0}, - "data_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "horizontal_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "pipeline_parallel": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "pipeline_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "num_pipeline_micro_batches": {"type": "integer", "min": 1, "default": 1}, - "pipeline_cut_info_string": {"type": "string", "default": ""}, - "sliced_schema": { - "type": "dict", - "default_setter": lambda _: {}, - "keysrules": {"type": "string"}, - "valuesrules": {"type": "list", "schema": {"type": "integer"}}, - }, - "sliced_axes": { - "type": "dict", - "default_setter": lambda _: {}, - "keysrules": {"type": "string"}, - "valuesrules": {"type": "integer"}, - }, - "sliced_tensor_names": {"type": "list", "schema": {"type": "string"}, "default": []}, - }, - }, - "allreduce_post_accumulation": {"type": "boolean", "default": False}, - "deepspeed_zero_optimization": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "stage": {"type": "integer", "min": 0, "max": 1, "default": 0}, - }, - }, - "enable_adasum": {"type": "boolean", "default": False}, - }, - }, - "lr_scheduler": {"type": "lr_scheduler", "nullable": True, "default": None}, - "mixed_precision": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "enabled": {"type": "boolean", "default": False}, - "loss_scaler": {"type": "loss_scaler", "nullable": True, "default": None}, - }, - }, - "graph_transformer": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "attn_dropout_recompute": {"type": "boolean", "default": False}, - "gelu_recompute": {"type": "boolean", "default": False}, - "transformer_layer_recompute": {"type": "boolean", "default": False}, - "number_recompute_layers": {"type": "integer", "min": 0, "default": 0}, - "propagate_cast_ops_config": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "strategy": { - "type": "propagate_cast_ops_strategy", - "nullable": True, - "default": PropagateCastOpsStrategy.FLOOD_FILL, - }, - "level": {"type": "integer", "min": -1, "default": 1}, - "allow": {"type": "list", "schema": {"type": "string"}, "default": []}, - }, - }, - }, - }, - "utils": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "frozen_weights": {"type": "list", "default": []}, - "grad_norm_clip": {"type": "boolean", "default": True}, - "memory_efficient_gradient": {"type": "boolean", "default": False}, - "run_symbolic_shape_infer": {"type": "boolean", "default": False}, - }, - }, - "debug": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "deterministic_compute": {"type": "boolean", "default": False}, - "check_model_export": {"type": "boolean", "default": False}, - "graph_save_paths": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "model_after_graph_transforms_path": {"type": "string", "default": ""}, - "model_with_gradient_graph_path": {"type": "string", "default": ""}, - "model_with_training_graph_path": {"type": "string", "default": ""}, - "model_with_training_graph_after_optimization_path": {"type": "string", "default": ""}, - }, - }, - }, - }, - "_internal_use": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "enable_internal_postprocess": {"type": "boolean", "default": True}, - "extra_postprocess": {"check_with": _check_is_callable, "nullable": True, "default": None}, - "onnx_opset_version": {"type": "integer", "min": 12, "max": 14, "default": 14}, - "enable_onnx_contrib_ops": {"type": "boolean", "default": True}, - }, - }, - "provider_options": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "allow_unknown": True, - "schema": {}, - }, - "session_options": {"type": "session_options", "nullable": True, "default": None}, -} diff --git a/orttraining/orttraining/python/training/postprocess.py b/orttraining/orttraining/python/training/postprocess.py deleted file mode 100644 index 6c2adb6af7978..0000000000000 --- a/orttraining/orttraining/python/training/postprocess.py +++ /dev/null @@ -1,478 +0,0 @@ -import os.path # noqa: F401 -import struct -import sys # noqa: F401 - -import numpy as np # noqa: F401 -import onnx -from onnx import * # noqa: F403 -from onnx import helper, numpy_helper # noqa: F401 - - -def run_postprocess(model): - # this post pass is not required for pytorch >= 1.5 - # where add_node_name in torch.onnx.export is default to True - model = add_name(model) - - # this post pass is not required for pytorch > 1.6 - model = fuse_softmaxNLL_to_softmaxCE(model) - - model = fix_expand_shape(model) - model = fix_expand_shape_pt_1_5(model) - return model - - -def find_input_node(model, arg): - result = [] - for node in model.graph.node: - for output in node.output: - if output == arg: - result.append(node) - return result[0] if len(result) == 1 else None - - -def find_output_node(model, arg): - result = [] - for node in model.graph.node: - for input in node.input: - if input == arg: - result.append(node) - return result[0] if len(result) == 1 else result - - -def add_name(model): - i = 0 - for node in model.graph.node: - node.name = "%s_%d" % (node.op_type, i) - i += 1 - return model - - -# Expand Shape PostProcess - - -def fix_expand_shape(model): - expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"] - model_inputs_names = [i.name for i in model.graph.input] - - for expand_node in expand_nodes: - shape = find_input_node(model, expand_node.input[1]) - if shape.op_type == "Shape": - # an expand subgraph - # Input Input2 - # | | - # | Shape - # | | - # |__ __| - # | | - # Expand - # | - # output - # - # Only if Input2 is one of the model inputs, assign Input2's shape to output of expand. - shape_input_name = shape.input[0] - if shape_input_name in model_inputs_names: - index = model_inputs_names.index(shape_input_name) - expand_out = model.graph.value_info.add() - expand_out.name = expand_node.output[0] - expand_out.type.CopyFrom(model.graph.input[index].type) - return model - - -def fix_expand_shape_pt_1_5(model): - # expand subgraph - # Constant - # + - # ConstantOfShape - # | + | - # | + | - # (Reshape subgraph) Mul | - # |___ _________| | - # + | | | - # + Equal | - # +++++|++++++++++++++|++ - # |____________ | + - # | | + - # (subgraph) Where - # | | - # |_____ ___________| - # | | - # Expand - # | - # output - # - # where the Reshape subgraph is - # - # Input - # | | - # | |___________________ - # | | - # Shape Constant Shape Constant - # | ______| | ______| - # | | | | - # Gather Gather - # | | - # Unsqueeze Unsqueeze - # | | - # | ..Number of dims.. | - # | _________________| - # |...| - # Concat Constant - # | | - # |______ __________________| - # | | - # Reshape - # | - # output - # - # This pass will copy Input's shape to the output of Expand. - expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"] - model_inputs_names = [i.name for i in model.graph.input] - - for expand_node in expand_nodes: - n_where = find_input_node(model, expand_node.input[1]) - if n_where.op_type != "Where": - continue - - n_equal = find_input_node(model, n_where.input[0]) - n_cos = find_input_node(model, n_where.input[1]) - n_reshape = find_input_node(model, n_where.input[2]) - - if n_equal.op_type != "Equal" or n_cos.op_type != "ConstantOfShape" or n_reshape.op_type != "Reshape": - continue - - n_reshape_e = find_input_node(model, n_equal.input[0]) - n_mul = find_input_node(model, n_equal.input[1]) - if n_reshape_e != n_reshape or n_mul.op_type != "Mul": - continue - - n_cos_m = find_input_node(model, n_mul.input[0]) - n_constant = find_input_node(model, n_mul.input[1]) - if n_cos_m != n_cos or n_constant.op_type != "Constant": - continue - - n_concat = find_input_node(model, n_reshape.input[0]) - n_constant_r = find_input_node(model, n_reshape.input[1]) - if n_concat.op_type != "Concat" or n_constant_r.op_type != "Constant": - continue - - n_input_candidates = [] - for concat_in in n_concat.input: - n_unsqueeze = find_input_node(model, concat_in) - if n_unsqueeze.op_type != "Unsqueeze": - break - n_gather = find_input_node(model, n_unsqueeze.input[0]) - if n_gather.op_type != "Gather": - break - n_shape = find_input_node(model, n_gather.input[0]) - n_constant_g = find_input_node(model, n_gather.input[1]) - if n_shape.op_type != "Shape" or n_constant_g.op_type != "Constant": - break - n_input = n_shape.input[0] - if n_input not in model_inputs_names: - break - n_input_candidates.append(n_input) - - if not n_input_candidates or not all(elem == n_input_candidates[0] for elem in n_input_candidates): - continue - - index = model_inputs_names.index(n_input_candidates[0]) - expand_out = model.graph.value_info.add() - expand_out.name = expand_node.output[0] - expand_out.type.CopyFrom(model.graph.input[index].type) - return model - - -# LayerNorm PostProcess - - -def find_nodes(graph, op_type): - nodes = [] - for node in graph.node: - if node.op_type == op_type: - nodes.append(node) - return nodes - - -def is_type(node, op_type): - if node is None or isinstance(node, list): - return False - return node.op_type == op_type - - -def add_const(model, name, output, t_value=None, f_value=None): - const_node = model.graph.node.add() - const_node.op_type = "Constant" - const_node.name = name - const_node.output.extend([output]) - attr = const_node.attribute.add() - attr.name = "value" - if t_value is not None: - attr.type = 4 - attr.t.CopyFrom(t_value) - else: - attr.type = 1 - attr.f = f_value - return const_node - - -def layer_norm_transform(model): - # DEPRECATED: This pass is no longer needed as the transform is handled at the backend. - # Converting below subgraph - # - # input - # | - # ReduceMean - # | - # Sub Constant - # _||_____ | - # | | | - # | | | - # | (optional) Cast (optional) Cast - # | | | - # | | ____________________| - # | | | - # | Pow - # | | - # | ReduceMean - # | | - # | Add - # | | - # |__ __Sqrt - # | | - # Div (weight) - # | | - # | _____| - # | | - # Mul (bias) - # | | - # | _____| - # | | - # Add - # | - # output - # - # to the below subgraph - # - # input (weight) (bias) - # | | | - # | _______| | - # | | ________________| - # | | | - # LayerNormalization - # | - # output - graph = model.graph - - nodes_ReduceMean = find_nodes(graph, "ReduceMean") # noqa: N806 - - id = 0 - layer_norm_nodes = [] - remove_nodes = [] - for reduce_mean in nodes_ReduceMean: - # check that reduce_mean output is Sub - sub = find_output_node(model, reduce_mean.output[0]) - if not is_type(sub, "Sub"): - continue - - # check that sub output[0] is Div and output[1] is Pow - pow, div = find_output_node(model, sub.output[0]) - if is_type(pow, "Cast"): - # During an update in PyTorch, Cast nodes are inserted between Sub and Pow. - remove_nodes += [pow] - pow = find_output_node(model, pow.output[0]) - if not is_type(pow, "Pow"): - continue - cast_pow = find_input_node(model, pow.input[1]) - if not is_type(cast_pow, "Cast"): - continue - remove_nodes += [cast_pow] - if not is_type(div, "Div") or not is_type(pow, "Pow"): - continue - - # check that pow ouput is ReduceMean - reduce_mean2 = find_output_node(model, pow.output[0]) - if not is_type(reduce_mean2, "ReduceMean"): - continue - - # check that reduce_mean2 output is Add - add = find_output_node(model, reduce_mean2.output[0]) - if not is_type(add, "Add"): - continue - - # check that add output is Sqrt - sqrt = find_output_node(model, add.output[0]) - if not is_type(sqrt, "Sqrt"): - continue - - # check that sqrt output is div - if div != find_output_node(model, sqrt.output[0]): - continue - - # check if div output is Mul - optional_mul = find_output_node(model, div.output[0]) - if not is_type(optional_mul, "Mul"): - optional_mul = None - continue # default bias and weight not supported - - # check if mul output is Add - if optional_mul is not None: - optional_add = find_output_node(model, optional_mul.output[0]) - else: - optional_add = find_output_node(model, div.output[0]) - if not is_type(optional_add, "Add"): - optional_add = None - continue # default bias and weight not supported - - # add nodes to remove_nodes - remove_nodes.extend([reduce_mean, sub, div, pow, reduce_mean2, add, sqrt]) - - # create LayerNorm node - layer_norm_input = [] - layer_norm_output = [] - - layer_norm_input.append(reduce_mean.input[0]) - - if optional_mul is not None: - remove_nodes.append(optional_mul) - weight = optional_mul.input[1] - layer_norm_input.append(weight) - - if optional_add is not None: - remove_nodes.append(optional_add) - bias = optional_add.input[1] - layer_norm_input.append(bias) - - if optional_add is not None: - layer_norm_output.append(optional_add.output[0]) - elif optional_mul is not None: - layer_norm_output.append(optional_mul.output[0]) - else: - layer_norm_output.append(div.output[0]) - - layer_norm_output.append("saved_mean_" + str(id)) - layer_norm_output.append("saved_inv_std_var_" + str(id)) - - epsilon_node = find_input_node(model, add.input[1]) - epsilon = epsilon_node.attribute[0].t.raw_data - epsilon = struct.unpack("f", epsilon)[0] - - layer_norm = helper.make_node( - "LayerNormalization", - layer_norm_input, - layer_norm_output, - "LayerNormalization_" + str(id), - None, - axis=reduce_mean.attribute[0].ints[0], - epsilon=epsilon, - ) - layer_norm_nodes.append(layer_norm) - id += 1 - - # remove orphan constant nodes - for constant in graph.node: - if constant.op_type == "Constant" and constant not in remove_nodes: - is_orphan = True - for out_name in constant.output: - out = find_output_node(model, out_name) - if out not in remove_nodes: - is_orphan = False - if is_orphan: - remove_nodes.append(constant) - - all_nodes = [] - for node in graph.node: - if node not in remove_nodes: - all_nodes.append(node) - - for node in layer_norm_nodes: - all_nodes.append(node) # noqa: PERF402 - - graph.ClearField("node") - graph.node.extend(all_nodes) - return model - - -# Fuse SoftmaxCrossEntropy - - -def fuse_softmaxNLL_to_softmaxCE(onnx_model): # noqa: N802 - # Converting below subgraph - # - # (subgraph) - # | - # LogSoftmax (target) (optional weight) - # | | | - # nll_loss/NegativeLogLikelihoodLoss - # | - # output - # - # to the following - # - # (subgraph) (target) (optional weight) - # | | _____| - # | | | - # SparseSoftmaxCrossEntropy - # | - # output - nll_count = 0 - while True: - nll_count = nll_count + 1 - nll_loss_node = None - nll_loss_node_index = 0 - for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss": - nll_loss_node = node - break - - if nll_loss_node is None: - break - - softmax_node = None - softmax_node_index = 0 - label_input_name = None - weight_input_name = None - for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "LogSoftmax": - # has to be connected to nll_loss - if len(nll_loss_node.input) > 2: - weight_input_name = nll_loss_node.input[2] - if node.output[0] == nll_loss_node.input[0]: - softmax_node = node - label_input_name = nll_loss_node.input[1] - break - elif node.output[0] == nll_loss_node.input[1]: - softmax_node = node - label_input_name = nll_loss_node.input[0] - break - else: - if softmax_node is not None: - break - - if softmax_node is None: - break - - # delete nll_loss and LogSoftmax nodes in order - if nll_loss_node_index < softmax_node_index: - del onnx_model.graph.node[softmax_node_index] - del onnx_model.graph.node[nll_loss_node_index] - else: - del onnx_model.graph.node[nll_loss_node_index] - del onnx_model.graph.node[softmax_node_index] - - probability_output_name = softmax_node.output[0] - node = onnx_model.graph.node.add() - inputs = ( - [softmax_node.input[0], label_input_name, weight_input_name] - if weight_input_name - else [softmax_node.input[0], label_input_name] - ) - node.CopyFrom( - onnx.helper.make_node( - "SparseSoftmaxCrossEntropy", - inputs, - [nll_loss_node.output[0], probability_output_name], - "nll_loss_node_" + str(nll_count), - ) - ) - - return onnx_model diff --git a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py b/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py deleted file mode 100644 index f57f55d14eb1b..0000000000000 --- a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py +++ /dev/null @@ -1,144 +0,0 @@ -import sys -import threading -import time - - -class OutputGrabber: - """ - Class used to grab standard output or another stream. - """ - - escape_char = "\b" - - def __init__(self, stream=None, threaded=False): - self.origstream = stream - self.threaded = threaded - if self.origstream is None: - self.origstream = sys.stdout - self.origstreamfd = self.origstream.fileno() - self.capturedtext = "" - # Create a pipe so the stream can be captured: - self.pipe_out, self.pipe_in = os.pipe() - - def __enter__(self): - self.start() - return self - - def __exit__(self, type, value, traceback): - self.stop() - - def start(self): - """ - Start capturing the stream data. - """ - self.capturedtext = "" - # Save a copy of the stream: - self.streamfd = os.dup(self.origstreamfd) - # Replace the original stream with our write pipe: - os.dup2(self.pipe_in, self.origstreamfd) - if self.threaded: - # Start thread that will read the stream: - self.workerThread = threading.Thread(target=self.readOutput) - self.workerThread.start() - # Make sure that the thread is running and os.read() has executed: - time.sleep(0.01) - - def stop(self): - """ - Stop capturing the stream data and save the text in `capturedtext`. - """ - # Print the escape character to make the readOutput method stop: - self.origstream.write(self.escape_char) - # Flush the stream to make sure all our data goes in before - # the escape character: - self.origstream.flush() - if self.threaded: - # wait until the thread finishes so we are sure that - # we have until the last character: - self.workerThread.join() - else: - self.readOutput() - # Close the pipe: - os.close(self.pipe_in) - os.close(self.pipe_out) - # Restore the original stream: - os.dup2(self.streamfd, self.origstreamfd) - # Close the duplicate stream: - os.close(self.streamfd) - - def readOutput(self): - """ - Read the stream data (one byte at a time) - and save the text in `capturedtext`. - """ - while True: - char = os.read(self.pipe_out, 1).decode(self.origstream.encoding) - if not char or self.escape_char in char: - break - self.capturedtext += char - - -import os # noqa: E402 -import unittest # noqa: E402 - -import numpy as np # noqa: E402, F401 -import torch # noqa: E402 -import torch.nn as nn # noqa: E402 -import torch.nn.functional as F # noqa: E402 - -from onnxruntime.capi import _pybind_state as torch_ort_eager # noqa: E402, F401 -from onnxruntime.training import optim, orttrainer, orttrainer_options # noqa: E402, F401 - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x, target): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return my_loss(out, target) - - -class OrtEPTests(unittest.TestCase): - def test_external_graph_transformer_triggering(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - batch_size = 128 - model = NeuralNet(input_size, hidden_size, num_classes) - - model_desc = { - "inputs": [ - ("x", [batch_size, input_size]), - ( - "target", - [ - batch_size, - ], - ), - ], - "outputs": [("loss", [], True)], - } - optim_config = optim.SGDConfig() - opts = orttrainer.ORTTrainerOptions({"device": {"id": "cpu"}}) - model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - # because orttrainer is lazy initialized, feed in a random data to trigger the graph transformer - data = torch.rand(batch_size, input_size) - target = torch.randint(0, 10, (batch_size,)) - - with OutputGrabber() as out: - model.train_step(data, target) - assert "******************Trigger Customized Graph Transformer: MyGraphTransformer!" in out.capturedtext - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc b/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc deleted file mode 100644 index 00e933dd14914..0000000000000 --- a/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc +++ /dev/null @@ -1,35 +0,0 @@ -#include "core/optimizer/rewrite_rule.h" -#include "orttraining/core/optimizer/graph_transformer_registry.h" -#include "onnx/defs/schema.h" -#include -#include - -namespace onnxruntime { -namespace training { - -class MyRewriteRule : public RewriteRule { - public: - MyRewriteRule() noexcept - : RewriteRule("MyRewriteRule") { - } - std::vector TargetOpTypes() const noexcept override { - return {}; - } - - private: - bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/, const logging::Logger& /*logger*/) const override { - return true; - } - - Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/, const logging::Logger& /*logger*/) const override { - std::cout << "******************Trigger Customized Graph Transformer: MyGraphTransformer!" << std::endl; - return Status::OK(); - } -}; - -void RegisterTrainingExternalTransformers() { - ONNX_REGISTER_EXTERNAL_REWRITE_RULE(MyRewriteRule, Level1, true); -} - -} // namespace training -} // namespace onnxruntime diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 20b9354d85745..b774fec11cc8d 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -35,6 +35,7 @@ #ifdef ENABLE_TRITON #include "orttraining/core/optimizer/triton_fusion.h" #endif +#include "orttraining/core/optimizer/conv1d_replacement.h" #include @@ -1199,6 +1200,103 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) { ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1); } +TEST_F(GraphTransformationTests, Conv1dReplacement) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + for (auto group : {1, 2}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / group, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(group)); + }; + + auto post_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 0); + // after graph transformation, the graph should have 1 squeeze, 2 split, group matmul, 1 concat + TEST_RETURN_IF_NOT(op_count_map["Squeeze"] == 1); + TEST_RETURN_IF_NOT(op_count_map["Split"] == 2); + TEST_RETURN_IF_NOT(op_count_map["MatMul"] == group); + TEST_RETURN_IF_NOT(op_count_map["Concat"] == 1); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + } +} + +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + // "group" is 3 so conv not replaced + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(3)); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } + + // "kernel_shape" is not 1 so conv not replaced + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{2}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(1)); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } +} + INSTANTIATE_TEST_SUITE_P( QDQFusionTests, QDQFusionTestsParameterized, diff --git a/orttraining/orttraining/test/python/_test_commons.py b/orttraining/orttraining/test/python/_test_commons.py index 1413d59096832..fb7e62551de63 100644 --- a/orttraining/orttraining/test/python/_test_commons.py +++ b/orttraining/orttraining/test/python/_test_commons.py @@ -1,26 +1,7 @@ -import copy -import math import os import subprocess import sys -import numpy as np -import onnx -import torch -from numpy.testing import assert_allclose - -import onnxruntime -from onnxruntime.training import _utils, optim - - -def _single_run(execution_file, scenario, checkopint_dir=None): - cmd = [sys.executable, execution_file] - if scenario: - cmd += ["--scenario", scenario] - if checkopint_dir: - cmd += ["--checkpoint_dir", checkopint_dir] - assert subprocess.call(cmd) == 0 - def is_windows(): return sys.platform.startswith("win") @@ -46,197 +27,3 @@ def run_subprocess(args, cwd=None, capture=False, dll_path=None, shell=False, en if log: log.debug("Subprocess completed. Return code=" + str(completed_process.returncode)) return completed_process - - -def legacy_constant_lr_scheduler(global_step, initial_lr, total_steps, warmup): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - new_lr = initial_lr - return new_lr - - -def legacy_cosine_lr_scheduler(global_step, initial_lr, total_steps, warmup, cycles): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - progress = float(global_step - num_warmup_steps) / float(max(1, total_steps - num_warmup_steps)) - new_lr = initial_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * progress))) - return new_lr - - -def legacy_linear_lr_scheduler(global_step, initial_lr, total_steps, warmup): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - new_lr = initial_lr * max(0.0, float(total_steps - global_step) / float(max(1, total_steps - num_warmup_steps))) - return new_lr - - -def legacy_poly_lr_scheduler(global_step, initial_lr, total_steps, warmup, power, lr_end): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - elif global_step > total_steps: - new_lr = lr_end - else: - lr_range = initial_lr - lr_end - decay_steps = total_steps - num_warmup_steps - pct_remaining = 1 - (global_step - num_warmup_steps) / decay_steps - decay = lr_range * pct_remaining**power + lr_end - new_lr = decay - return new_lr - - -def generate_dummy_optim_state(model, optimizer): - np.random.seed(0) - if not (isinstance(optimizer, (optim.AdamConfig, optim.LambConfig))): - return dict() - - moment_keys = ["Moment_1", "Moment_2"] - uc_key = "Update_Count" - step_key = "Step" - shared_state_key = "shared_optimizer_state" - - optim_state = dict() - weight_shape_map = dict() - if isinstance(model, torch.nn.Module): - weight_shape_map = {name: param.size() for name, param in model.named_parameters()} - elif isinstance(model, onnx.ModelProto): - weight_shape_map = {n.name: n.dims for n in model.graph.initializer} - else: - raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'") - - for weight_name, weight_shape in weight_shape_map.items(): - per_weight_state = dict() - for moment in moment_keys: - per_weight_state[moment] = np.random.uniform(-2, 2, weight_shape).astype(np.float32) - if isinstance(optimizer, optim.AdamConfig): - per_weight_state[uc_key] = np.full([1], 5, dtype=np.int64) - optim_state[weight_name] = copy.deepcopy(per_weight_state) - if isinstance(optimizer, optim.LambConfig): - step_val = np.full([1], 5, dtype=np.int64) - optim_state[shared_state_key] = {step_key: step_val} - return {"optimizer": optim_state, "trainer_options": {"optimizer_name": optimizer.name}} - - -def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False, data_dir=None): - # Loads external Pytorch TransformerModel into utils - root = "samples" - if not os.path.exists(root): - root = os.path.normpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "..", "samples") - ) - if not os.path.exists(root): - raise FileNotFoundError("Unable to find folder 'samples', tried %r." % root) - pytorch_transformer_path = os.path.join(root, "python", "training", "orttrainer", "pytorch_transformer") - pt_model_path = os.path.join(pytorch_transformer_path, "pt_model.py") - pt_model = _utils.import_module_from_file(pt_model_path) - ort_utils_path = os.path.join(pytorch_transformer_path, "ort_utils.py") - ort_utils = _utils.import_module_from_file(ort_utils_path) - utils_path = os.path.join(pytorch_transformer_path, "utils.py") - utils = _utils.import_module_from_file(utils_path) - - # Modeling - model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - my_loss = ort_utils.my_loss - if legacy_api: - if dynamic_axes: - model_desc = ort_utils.legacy_transformer_model_description_dynamic_axes() - else: - model_desc = ort_utils.legacy_transformer_model_description() - else: - if dynamic_axes: - model_desc = ort_utils.transformer_model_description_dynamic_axes() - else: - model_desc = ort_utils.transformer_model_description() - - # Preparing data - train_data, val_data, test_data = utils.prepare_data(device, 20, 20, data_dir) - return model, model_desc, my_loss, utils.get_batch, train_data, val_data, test_data - - -def generate_random_input_from_bart_model_desc(desc, seed=1, device="cuda:0"): - """Generates a sample input for the BART model using the model desc""" - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - dtype = torch.int64 - vocab_size = 30528 - sample_input = [] - for _index, input in enumerate(desc["inputs"]): - size = [] - for s in input[1]: - if isinstance(s, (int)): - size.append(s) - else: - size.append(1) - sample_input.append(torch.randint(0, vocab_size, tuple(size), dtype=dtype).to(device)) - return sample_input - - -def _load_bart_model(): - bart_onnx_model_path = os.path.join("testdata", "bart_tiny.onnx") - model = onnx.load(bart_onnx_model_path) - batch = 2 - seq_len = 1024 - model_desc = { - "inputs": [ - ( - "src_tokens", - [batch, seq_len], - ), - ( - "prev_output_tokens", - [batch, seq_len], - ), - ( - "target", - [batch * seq_len], - ), - ], - "outputs": [("loss", [], True)], - } - - return model, model_desc - - -def assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint, reshape_states=False): - """Assert that the two ORTTrainer (hierarchical) state dictionaries are very close for all states""" - - assert ("model" in state_dict_pre_checkpoint) == ("model" in state_dict_post_checkpoint) - assert ("optimizer" in state_dict_pre_checkpoint) == ("optimizer" in state_dict_post_checkpoint) - - if "model" in state_dict_pre_checkpoint: - for model_state_key in state_dict_pre_checkpoint["model"]["full_precision"]: - if reshape_states: - assert_allclose( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key], - state_dict_post_checkpoint["model"]["full_precision"][model_state_key].reshape( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key].shape - ), - ) - else: - assert_allclose( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key], - state_dict_post_checkpoint["model"]["full_precision"][model_state_key], - ) - - if "optimizer" in state_dict_pre_checkpoint: - for model_state_key in state_dict_pre_checkpoint["optimizer"]: - for optimizer_state_key in state_dict_pre_checkpoint["optimizer"][model_state_key]: - if reshape_states: - assert_allclose( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key], - state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key].reshape( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key].shape - ), - ) - else: - assert_allclose( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key], - state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key], - ) diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index a9a4c7b1cc2ef..8f2a18b5ec00b 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -1,30 +1,11 @@ import copy import os -import numpy as np import torch from numpy.testing import assert_allclose -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import orttrainer - -try: - from onnxruntime.training.ortmodule import ORTModule - from onnxruntime.training.ortmodule._fallback import ORTModuleInitException - from onnxruntime.training.ortmodule._graph_execution_manager_factory import ( # noqa: F401 - GraphExecutionManagerFactory, - ) -except ImportError: - # Some pipelines do not contain ORTModule - pass -except Exception as e: - from onnxruntime.training.ortmodule._fallback import ORTModuleInitException - - if isinstance(e, ORTModuleInitException): - # ORTModule is present but not ready to run - # That is OK because this file is also used by ORTTrainer tests - pass - raise +from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory # noqa: F401 def is_all_or_nothing_fallback_enabled(model, policy=None): @@ -66,103 +47,6 @@ def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0): assert_allclose(output_a, output_b, rtol=rtol, atol=atol, err_msg="Model output value mismatch") -def assert_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0): - r"""Asserts whether weight difference between models a and b differences are within specified tolerance - - Compares the weights of two different ONNX models (model_a and model_b) - and raises AssertError when they diverge by more than atol or rtol - - Args: - model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure - verbose (bool, default is False): if True, prints absolute difference for each weight - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer) - state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state() - assert len(state_dict_a.items()) == len(state_dict_b.items()) - _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol) - - -def assert_legacy_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0): - r"""Asserts whether weight difference between models a and b differences are within specified tolerance - - Compares the weights of a legacy model model_a and experimental model_b model - and raises AssertError when they diverge by more than atol or rtol. - - Args: - model_a (ORTTrainer): Instance of legacy ORTTrainer - model_b (ORTTrainer): Instance of experimental ORTTrainer - verbose (bool, default is False): if True, prints absolute difference for each weight. - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, Legacy_ORTTrainer) - state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b.session.get_state() - assert len(state_dict_a.items()) == len(state_dict_b.items()) - _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol) - - -def _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol): - r"""Asserts whether dicts a and b value differences are within specified tolerance - - Compares the weights of two model's state_dict dicts and raises AssertError - when they diverge by more than atol or rtol - - Args: - model_a (ORTTrainer): Instance of legacy ORTTrainer - model_b (ORTTrainer): Instance of experimental ORTTrainer - verbose (bool, default is False): if True, prints absolute difference for each weight. - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - - for (a_name, a_val), (_b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()): - np_a_vals = np.array(a_val).flatten() - np_b_vals = np.array(b_val).flatten() - assert np_a_vals.shape == np_b_vals.shape - if verbose: - print(f"Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}") - assert_allclose(a_val, b_val, rtol=rtol, atol=atol, err_msg=f"Weight mismatch for {a_name}") - - -def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0): - r"""Asserts whether optimizer state differences are within specified tolerance - - Compares the expected and actual optimizer states of dicts and raises AssertError - when they diverge by more than atol or rtol. - The optimizer dict is of the form: - model_weight_name: - { - "Moment_1": moment1_tensor, - "Moment_2": moment2_tensor, - "Update_Count": update_tensor # if optimizer is adam, absent otherwise - }, - ... - "shared_optimizer_state": # if optimizer is shared, absent otherwise. - So far, only lamb optimizer uses this. - { - "step": step_tensor # int array of size 1 - } - - Args: - expected_state (dict(dict())): Expected optimizer state - actual_state (dict(dict())): Actual optimizer state - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 0): Max absolute difference - """ - assert expected_state.keys() == actual_state.keys() - for param_name, a_state in actual_state.items(): - for k, v in a_state.items(): - assert_allclose( - v, - expected_state[param_name][k], - rtol=rtol, - atol=atol, - err_msg=f"Optimizer state mismatch for param {param_name}, key {k}", - ) - - def is_dynamic_axes(model): # Check inputs for inp in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.input: diff --git a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py b/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py deleted file mode 100644 index d5298cf8e860e..0000000000000 --- a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py +++ /dev/null @@ -1,325 +0,0 @@ -import os -import unittest - -import torch -import torch.nn as nn -from orttraining_test_bert_postprocess import postprocess_model -from orttraining_test_data_loader import create_ort_test_dataloader -from orttraining_test_transformers import BertForPreTraining, BertModelTest -from orttraining_test_utils import map_optimizer_attributes - -import onnxruntime -from onnxruntime.capi.ort_trainer import ( # noqa: F401 - IODescription, - LossScaler, - ModelDescription, - ORTTrainer, - generate_sample, -) - -torch.manual_seed(1) -onnxruntime.set_seed(1) - - -class Test_PostPasses(unittest.TestCase): # noqa: N801 - def get_onnx_model( - self, model, model_desc, inputs, device, _enable_internal_postprocess=True, _extra_postprocess=None - ): - lr_desc = IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes, - lr_desc, - device, - world_rank=0, - world_size=1, - _opset_version=14, - _enable_internal_postprocess=_enable_internal_postprocess, - _extra_postprocess=_extra_postprocess, - ) - - model.train_step(*inputs) - return model.onnx_model_ - - def count_all_nodes(self, model): - return len(model.graph.node) - - def count_nodes(self, model, node_type): - count = 0 - for node in model.graph.node: - if node.op_type == node_type: - count += 1 - return count - - def find_nodes(self, model, node_type): - nodes = [] - for node in model.graph.node: - if node.op_type == node_type: - nodes.append(node) - return nodes - - def get_name(self, name): - if os.path.exists(name): - return name - rel = os.path.join("testdata", name) - if os.path.exists(rel): - return rel - this = os.path.dirname(__file__) - data = os.path.join(this, "..", "..", "..", "..", "onnxruntime", "test", "testdata") - res = os.path.join(data, name) - if os.path.exists(res): - return res - raise FileNotFoundError(f"Unable to find '{name}' or '{rel}' or '{res}'") - - def test_layer_norm(self): - class LayerNormNet(nn.Module): - def __init__(self, target): - super().__init__() - self.ln_1 = nn.LayerNorm(10) - self.loss = nn.CrossEntropyLoss() - self.target = target - - def forward(self, x): - output1 = self.ln_1(x) - loss = self.loss(output1, self.target) - return loss, output1 - - device = torch.device("cpu") - target = torch.ones(20, 10, 10, dtype=torch.int64).to(device) - model = LayerNormNet(target) - input = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(device) - - input_desc = IODescription("input", [], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [20, 5, 10, 10], "float32") - model_desc = ModelDescription([input_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [input, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, input_args, device) - - count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization") - count_nodes = self.count_all_nodes(onnx_model) - - assert count_layer_norm == 0 - assert count_nodes == 3 - - def test_expand(self): - class ExpandNet(nn.Module): - def __init__(self, target): - super().__init__() - self.loss = nn.CrossEntropyLoss() - self.target = target - self.linear = torch.nn.Linear(2, 2) - - def forward(self, x, x1): - output = x.expand_as(x1) - output = self.linear(output) - output = output + output - loss = self.loss(output, self.target) - return loss, output - - device = torch.device("cpu") - target = torch.ones(5, 5, 2, dtype=torch.int64).to(device) - model = ExpandNet(target).to(device) - - x = torch.randn(5, 3, 1, 2, dtype=torch.float32).to(device) - x1 = torch.randn(5, 3, 5, 2, dtype=torch.float32).to(device) - - input0_desc = IODescription("x", [5, 3, 1, 2], "float32") - input1_desc = IODescription("x1", [5, 3, 5, 2], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [5, 3, 5, 2], "float32") - model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [x, x1, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, input_args, device) - - # check that expand output has shape - expand_nodes = self.find_nodes(onnx_model, "Expand") - assert len(expand_nodes) == 1 - - model_info = onnx_model.graph.value_info - assert model_info[0].name == expand_nodes[0].output[0] - assert model_info[0].type == onnx_model.graph.input[1].type - - def test_bert(self): - device = torch.device("cpu") - - model_tester = BertModelTest.BertModelTester(self) - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = model_tester.prepare_config_and_inputs() - - model = BertForPreTraining(config=config) - model.eval() - - loss, prediction_scores, seq_relationship_score = model( - input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - masked_lm_labels=token_labels, - next_sentence_label=sequence_labels, - ) - - model_desc = ModelDescription( - [ - model_tester.input_ids_desc, - model_tester.attention_mask_desc, - model_tester.token_type_ids_desc, - model_tester.masked_lm_labels_desc, - model_tester.next_sentence_label_desc, - ], - [model_tester.loss_desc, model_tester.prediction_scores_desc, model_tester.seq_relationship_scores_desc], - ) - - from collections import namedtuple - - MyArgs = namedtuple( - "MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len" - ) - args = MyArgs( - local_rank=0, - world_size=1, - max_steps=100, - learning_rate=0.00001, - warmup_proportion=0.01, - batch_size=13, - seq_len=7, - ) - - dataset_len = 100 - dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device) - learning_rate = torch.tensor(1.0e0, dtype=torch.float32).to(device) - for b in dataloader: - batch = b - break - learning_rate = torch.tensor([1.00e00]).to(device) - inputs = [*batch, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, inputs, device, _extra_postprocess=postprocess_model) - - self._bert_helper(onnx_model) - - def _bert_helper(self, onnx_model): - # count layer_norm - count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization") - assert count_layer_norm == 0 - - # get expand node and check output shape - expand_nodes = self.find_nodes(onnx_model, "Expand") - assert len(expand_nodes) == 1 - - model_info = onnx_model.graph.value_info - assert model_info[0].name == expand_nodes[0].output[0] - assert model_info[0].type == onnx_model.graph.input[0].type - - def test_extra_postpass(self): - def postpass_replace_first_add_with_sub(model): - # this post pass replaces the first Add node with Sub in the model. - # Previous graph - # (subgraph 1) (subgraph 2) - # | | - # | | - # |________ ________| - # | | - # Add - # | - # (subgraph 3) - # - # Post graph - # (subgraph 1) (subgraph 2) - # | | - # | | - # |________ ________| - # | | - # Sub - # | - # (subgraph 3) - add_nodes = [n for n in model.graph.node if n.op_type == "Add"] - add_nodes[0].op_type = "Sub" - - class MultiAdd(nn.Module): - def __init__(self, target): - super().__init__() - self.loss = nn.CrossEntropyLoss() - self.target = target - self.linear = torch.nn.Linear(2, 2, bias=False) - - def forward(self, x, x1): - output = x + x1 - output = output + x - output = output + x1 - output = self.linear(output) - loss = self.loss(output, self.target) - return loss, output - - device = torch.device("cpu") - target = torch.ones(5, 2, dtype=torch.int64).to(device) - model = MultiAdd(target).to(device) - - x = torch.randn(5, 5, 2, dtype=torch.float32).to(device) - x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device) - - input0_desc = IODescription("x", [5, 5, 2], "float32") - input1_desc = IODescription("x1", [5, 5, 2], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [5, 5, 2], "float32") - model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [x, x1, learning_rate] - - onnx_model = self.get_onnx_model( - model, model_desc, input_args, device, _extra_postprocess=postpass_replace_first_add_with_sub - ) - - # check that extra postpass is called, and called only once. - add_nodes = self.find_nodes(onnx_model, "Add") - sub_nodes = self.find_nodes(onnx_model, "Sub") - assert len(add_nodes) == 2 - assert len(sub_nodes) == 1 - - unprocessed_onnx_model = self.get_onnx_model( - model, model_desc, input_args, device, _extra_postprocess=None, _enable_internal_postprocess=False - ) - # check that the model is unchanged. - add_nodes = self.find_nodes(unprocessed_onnx_model, "Add") - sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub") - assert len(add_nodes) == 3 - assert len(sub_nodes) == 0 - - processed_onnx_model = self.get_onnx_model( - unprocessed_onnx_model, - model_desc, - input_args, - device, - _extra_postprocess=postpass_replace_first_add_with_sub, - ) - # check that extra postpass is called, and called only once. - add_nodes = self.find_nodes(processed_onnx_model, "Add") - sub_nodes = self.find_nodes(processed_onnx_model, "Sub") - assert len(add_nodes) == 2 - assert len(sub_nodes) == 1 - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index 0e7e9d23ee627..5341cd053ac18 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -43,7 +43,7 @@ def run_ortmodule_ops_tests(cwd, log, transformers_cache): env = get_env_with_transformers_cache(transformers_cache) - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnx_ops_ortmodule.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_onnx_ops.py"] run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode() @@ -146,7 +146,7 @@ def run_data_sampler_tests(cwd, log): def run_hooks_tests(cwd, log): log.debug("Running: Data hooks tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_hooks.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_hooks.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() diff --git a/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py b/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py deleted file mode 100644 index eea733684f140..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py +++ /dev/null @@ -1,801 +0,0 @@ -# ================== -import dataclasses -import datetime -import glob -import json -import logging -import os -import random -import shutil -import unittest -from concurrent.futures import ProcessPoolExecutor -from dataclasses import dataclass, field -from typing import Any, Dict, Optional - -import h5py -import numpy as np -import torch -import torch.distributed as dist -from torch.utils.data import DataLoader, Dataset, RandomSampler -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers import BertConfig, BertForPreTraining, HfArgumentParser - -import onnxruntime as ort - -# need to override torch.onnx.symbolic_opset12.nll_loss to handle ignore_index == -100 cases. -# the fix for ignore_index == -100 cases is already in pytorch master. -# however to use current torch master is causing computation changes in many tests. -# eventually we will use pytorch with fixed nll_loss once computation -# issues are understood and solved. -import onnxruntime.capi.pt_patch -from onnxruntime.training import amp, optim, orttrainer -from onnxruntime.training.checkpoint import aggregate_checkpoints -from onnxruntime.training.optim import LinearWarmupLRScheduler, PolyWarmupLRScheduler # noqa: F401 - -# we cannot make full convergence run in nightly pipeling because of its timeout limit, -# max_steps is still needed to calculate learning rate. force_to_stop_max_steps is used to -# terminate the training before the pipeline run hit its timeout. -force_to_stop_max_steps = 2500 - -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO -) -logger = logging.getLogger(__name__) - - -def get_rank(): - if not dist.is_available(): - return 0 - if not dist.is_initialized(): - return 0 - return dist.get_rank() - - -def is_main_process(args): - if hasattr(args, "world_rank"): - return args.world_rank in [-1, 0] - else: - return get_rank() == 0 - - -def bert_model_description(config): - vocab_size = config.vocab_size - new_model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - ), - ( - "next_sentence_label", - [ - "batch", - ], - ), - ], - "outputs": [ - ("loss", [], True), - ( - "prediction_scores", - ["batch", "max_seq_len_in_batch", vocab_size], - ), - ( - "seq_relationship_scores", - ["batch", 2], - ), - ], - } - return new_model_desc - - -def create_pretraining_dataset(input_file, max_pred_length, args): - train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length) - train_sampler = RandomSampler(train_data) - train_dataloader = DataLoader( - train_data, sampler=train_sampler, batch_size=args.train_batch_size * args.n_gpu, num_workers=0, pin_memory=True - ) - return train_dataloader, input_file - - -class pretraining_dataset(Dataset): # noqa: N801 - def __init__(self, input_file, max_pred_length): - logger.info("pretraining_dataset: %s, max_pred_length: %d", input_file, max_pred_length) - self.input_file = input_file - self.max_pred_length = max_pred_length - f = h5py.File(input_file, "r") - keys = [ - "input_ids", - "input_mask", - "segment_ids", - "masked_lm_positions", - "masked_lm_ids", - "next_sentence_labels", - ] - self.inputs = [np.asarray(f[key][:]) for key in keys] - f.close() - - def __len__(self): - "Denotes the total number of samples" - return len(self.inputs[0]) - - def __getitem__(self, index): - [input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) - if indice < 5 - else torch.from_numpy(np.asarray(input[index].astype(np.int64))) - for indice, input in enumerate(self.inputs) - ] - - # HF model use default ignore_index value (-100) for CrossEntropyLoss - masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -100 - index = self.max_pred_length - # store number of masked tokens in index - padded_mask_indices = (masked_lm_positions == 0).nonzero() - if len(padded_mask_indices) != 0: - index = padded_mask_indices[0].item() - masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] - return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels] - - -import argparse # noqa: E402 - - -def parse_arguments(): - parser = argparse.ArgumentParser() - - # batch size test config parameters - parser.add_argument( - "--enable_mixed_precision", - default=False, - action="store_true", - help="Whether to use 16-bit float precision instead of 32-bit", - ) - - parser.add_argument( - "--sequence_length", - default=512, - type=int, - help="The maximum total input sequence length after WordPiece tokenization. \n" - "Sequences longer than this will be truncated, and sequences shorter \n" - "than this will be padded.", - ) - parser.add_argument( - "--max_predictions_per_seq", default=80, type=int, help="The maximum total of masked tokens in input sequence" - ) - parser.add_argument("--max_batch_size", default=32, type=int, help="Total batch size for training.") - - parser.add_argument("--gelu_recompute", default=False, action="store_true") - - parser.add_argument("--attn_dropout_recompute", default=False, action="store_true") - - parser.add_argument("--transformer_layer_recompute", default=False, action="store_true") - - args = parser.parse_args() - return args - - -@dataclass -class PretrainArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - input_dir: str = field( - default=None, metadata={"help": "The input data dir. Should contain .hdf5 files for the task"} - ) - - bert_model: str = field( - default=None, - metadata={ - "help": "Bert pre-trained model selected in the list: bert-base-uncased, \ - bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." - }, - ) - - output_dir: str = field( - default=None, metadata={"help": "The output directory where the model checkpoints will be written."} - ) - - cache_dir: str = field( - default="/tmp/bert_pretrain/", - metadata={"help": "The output directory where the model checkpoints will be written."}, - ) - max_seq_length: Optional[int] = field( - default=512, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer \ - than this will be truncated, sequences shorter will be padded." - }, - ) - - max_predictions_per_seq: Optional[int] = field( - default=80, metadata={"help": "The maximum total of masked tokens in input sequence."} - ) - - train_batch_size: Optional[int] = field(default=32, metadata={"help": "Batch size for training."}) - - learning_rate: Optional[float] = field(default=5e-5, metadata={"help": "The initial learning rate for Lamb."}) - - num_train_epochs: Optional[float] = field( - default=3.0, metadata={"help": "Total number of training epochs to perform."} - ) - - max_steps: Optional[float] = field(default=1000, metadata={"help": "Total number of training steps to perform."}) - - warmup_proportion: Optional[float] = field( - default=0.01, - metadata={ - "help": "Proportion of training to perform linear learning rate warmup for. \ - E.g., 0.1 = 10%% of training." - }, - ) - - local_rank: Optional[int] = field(default=-1, metadata={"help": "local_rank for distributed training on gpus."}) - - world_rank: Optional[int] = field(default=-1) - - world_size: Optional[int] = field(default=1) - - seed: Optional[int] = field(default=42, metadata={"help": "random seed for initialization."}) - - gradient_accumulation_steps: Optional[int] = field( - default=1, metadata={"help": "Number of updates steps to accumualte before performing a backward/update pass."} - ) - - fp16: bool = field(default=False, metadata={"help": "Whether to use 16-bit float precision instead of 32-bit."}) - - gelu_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing Gelu activation output to save memory."} - ) - attn_dropout_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing attention dropout to save memory."} - ) - transformer_layer_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing transformer layerwise to save memory."} - ) - - loss_scale: Optional[float] = field( - default=0.0, metadata={"help": "Loss scaling, positive power of 2 values can improve fp16 convergence."} - ) - - deepspeed_zero_stage: Optional[int] = field(default=0, metadata={"help": "Deepspeed Zero Stage. 0 => disabled"}) - - log_freq: Optional[float] = field(default=1.0, metadata={"help": "frequency of logging loss."}) - - checkpoint_activations: bool = field(default=False, metadata={"help": "Whether to use gradient checkpointing."}) - - resume_from_checkpoint: bool = field( - default=False, metadata={"help": "Whether to resume training from checkpoint."} - ) - - resume_step: Optional[int] = field(default=-1, metadata={"help": "Step to resume training from."}) - - num_steps_per_checkpoint: Optional[int] = field( - default=100, metadata={"help": "Number of update steps until a model checkpoint is saved to disk."} - ) - - save_checkpoint: Optional[bool] = field( - default=False, metadata={"help": "Enable for saving a model checkpoint to disk."} - ) - - init_state_dict: Optional[dict] = field(default=None, metadata={"help": "State to load before training."}) - - phase2: bool = field(default=False, metadata={"help": "Whether to train with seq len 512."}) - - allreduce_post_accumulation: bool = field( - default=False, metadata={"help": "Whether to do allreduces during gradient accumulation steps."} - ) - - allreduce_post_accumulation_fp16: bool = field( - default=False, metadata={"help": "Whether to do fp16 allreduce post accumulation."} - ) - - accumulate_into_fp16: bool = field(default=False, metadata={"help": "Whether to use fp16 gradient accumulators."}) - - phase1_end_step: Optional[int] = field( - default=7038, metadata={"help": "Whether to use fp16 gradient accumulators."} - ) - - tensorboard_dir: Optional[str] = field( - default=None, - ) - - schedule: Optional[str] = field( - default="warmup_poly", - ) - - # this argument is test specific. to run a full bert model will take too long to run. instead, we reduce - # number of hidden layers so that it can show convergence to an extend to help detect any regression. - force_num_hidden_layers: Optional[int] = field( - default=None, metadata={"help": "Whether to use fp16 gradient accumulators."} - ) - - def to_json_string(self): - """ - Serializes this instance to a JSON string. - """ - return json.dumps(dataclasses.asdict(self), indent=2) - - def to_sanitized_dict(self) -> Dict[str, Any]: - """ - Sanitized serialization to use with TensorBoard`s hparams - """ - d = dataclasses.asdict(self) - valid_types = [bool, int, float, str, torch.Tensor] - return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} - - -def setup_training(args): - assert torch.cuda.is_available() - - if args.local_rank == -1: - args.local_rank = 0 - args.world_rank = 0 - - print("args.local_rank: ", args.local_rank) - torch.cuda.set_device(args.local_rank) - device = torch.device("cuda", args.local_rank) - args.n_gpu = 1 - - if args.gradient_accumulation_steps < 1: - raise ValueError( - f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1" - ) - if args.train_batch_size % args.gradient_accumulation_steps != 0: - raise ValueError( - "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format( - args.gradient_accumulation_steps, args.train_batch_size - ) - ) - - # args.train_batch_size is per global step (optimization step) batch size - # now make it a per gpu batch size - args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps - args.train_batch_size = args.train_batch_size // args.world_size - - logger.info("setup_training: args.train_batch_size = %d", args.train_batch_size) - return device, args - - -def setup_torch_distributed(world_rank, world_size): - os.environ["RANK"] = str(world_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12345" - torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=world_rank) - return - - -def prepare_model(args, device): - config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir) - - # config.num_hidden_layers = 12 - if args.force_num_hidden_layers: - logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers) - config.num_hidden_layers = args.force_num_hidden_layers - - model = BertForPreTraining(config) - if args.init_state_dict is not None: - model.load_state_dict(args.init_state_dict) - model_desc = bert_model_description(config) - - lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion) - - loss_scaler = amp.DynamicLossScaler() if args.fp16 else None - - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": args.gradient_accumulation_steps}, - "device": {"id": str(device)}, - "mixed_precision": {"enabled": args.fp16, "loss_scaler": loss_scaler}, - "graph_transformer": { - "attn_dropout_recompute": args.attn_dropout_recompute, - "gelu_recompute": args.gelu_recompute, - "transformer_layer_recompute": args.transformer_layer_recompute, - }, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": True}, - "distributed": { - "world_rank": max(0, args.local_rank), - "world_size": args.world_size, - "local_rank": max(0, args.local_rank), - "allreduce_post_accumulation": args.allreduce_post_accumulation, - "deepspeed_zero_optimization": {"stage": args.deepspeed_zero_stage}, - "enable_adasum": False, - }, - "lr_scheduler": lr_scheduler, - } - ) - - param_optimizer = list(model.named_parameters()) - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - params = [ - { - "params": [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - { - "params": [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - ] - - optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) - model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options) - - return model - - -def get_data_file(f_id, world_rank, world_size, files): - num_files = len(files) - if world_size > num_files: - remainder = world_size % num_files - return files[(f_id * world_size + world_rank + remainder * f_id) % num_files] - elif world_size > 1: - return files[(f_id * world_size + world_rank) % num_files] - else: - return files[f_id % num_files] - - -def main(): - parser = HfArgumentParser(PretrainArguments) - args = parser.parse_args_into_dataclasses()[0] - do_pretrain(args) - - -def do_pretrain(args): - if is_main_process(args) and args.tensorboard_dir: - tb_writer = SummaryWriter(log_dir=args.tensorboard_dir) - tb_writer.add_text("args", args.to_json_string()) - tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={}) - else: - tb_writer = None - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - ort.set_seed(args.seed) - - device, args = setup_training(args) - - model = prepare_model(args, device) - - logger.info("Running training: Batch size = %d, initial LR = %f", args.train_batch_size, args.learning_rate) - - average_loss = 0.0 - epoch = 0 - training_steps = 0 - - pool = ProcessPoolExecutor(1) - while True: - files = [ - os.path.join(args.input_dir, f) - for f in os.listdir(args.input_dir) - if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in f - ] - files.sort() - random.shuffle(files) - - f_id = 0 - train_dataloader, data_file = create_pretraining_dataset( - get_data_file(f_id, args.world_rank, args.world_size, files), args.max_predictions_per_seq, args - ) - - for f_id in range(1, len(files)): - logger.info("data file %s" % (data_file)) - - dataset_future = pool.submit( - create_pretraining_dataset, - get_data_file(f_id, args.world_rank, args.world_size, files), - args.max_predictions_per_seq, - args, - ) - - train_iter = tqdm(train_dataloader, desc="Iteration") if is_main_process(args) else train_dataloader - for _step, batch in enumerate(train_iter): - training_steps += 1 - batch = [t.to(device) for t in batch] # noqa: PLW2901 - input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch - - loss, _, _ = model.train_step( - input_ids, input_mask, segment_ids, masked_lm_labels, next_sentence_labels - ) - average_loss += loss.item() - - global_step = model._train_step_info.optimization_step - if training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0: - if is_main_process(args): - divisor = args.log_freq * args.gradient_accumulation_steps - if tb_writer: - lr = model.options.lr_scheduler.get_last_lr()[0] - tb_writer.add_scalar("train/summary/scalar/Learning_Rate", lr, global_step) - if args.fp16: - tb_writer.add_scalar("train/summary/scalar/loss_scale_25", loss, global_step) - # TODO: ORTTrainer to expose all_finite - # tb_writer.add_scalar('train/summary/scalar/all_fp16_gradients_finite_859', all_finite, global_step) - tb_writer.add_scalar("train/summary/total_loss", average_loss / divisor, global_step) - - print(f"Step:{global_step} Average Loss = {average_loss / divisor}") - - if global_step >= args.max_steps or global_step >= force_to_stop_max_steps: - if tb_writer: - tb_writer.close() - - if global_step >= args.max_steps: - if args.save_checkpoint: - model.save_checkpoint(os.path.join(args.output_dir, f"checkpoint-{args.world_rank}.ortcp")) - final_loss = average_loss / (args.log_freq * args.gradient_accumulation_steps) - return final_loss - - average_loss = 0 - - del train_dataloader - - train_dataloader, data_file = dataset_future.result(timeout=None) - - epoch += 1 - - -def generate_tensorboard_logdir(root_dir): - current_date_time = datetime.datetime.today() - - dt_string = current_date_time.strftime("BERT_pretrain_%y_%m_%d_%I_%M_%S") - return os.path.join(root_dir, dt_string) - - -class ORTBertPretrainTest(unittest.TestCase): - def setUp(self): - self.output_dir = "/bert_data/hf_data/test_out/bert_pretrain_results" - self.bert_model = "bert-base-uncased" - self.local_rank = -1 - self.world_rank = -1 - self.world_size = 1 - self.max_steps = 300000 - self.learning_rate = 5e-4 - self.max_seq_length = 512 - self.max_predictions_per_seq = 20 - self.input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - self.train_batch_size = 4096 - self.gradient_accumulation_steps = 64 - self.fp16 = True - self.allreduce_post_accumulation = True - self.tensorboard_dir = "/bert_data/hf_data/test_out" - - def test_pretrain_throughput(self, process_args=None): - if process_args.sequence_length == 128: - input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - else: - input_dir = "/bert_data/hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - - print("process_args.enable_mixed_precision: ", process_args.enable_mixed_precision) - print("process_args.sequence_length: ", process_args.sequence_length) - print("process_args.max_batch_size: ", process_args.max_batch_size) - print("process_args.max_predictions_per_seq: ", process_args.max_predictions_per_seq) - print("process_args.gelu_recompute: ", process_args.gelu_recompute) - print("process_args.attn_dropout_recompute: ", process_args.attn_dropout_recompute) - print("process_args.transformer_layer_recompute: ", process_args.transformer_layer_recompute) - - args = PretrainArguments( - input_dir=input_dir, - output_dir="/bert_data/hf_data/test_out/bert_pretrain_results", - bert_model="bert-large-uncased", - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=10, - learning_rate=5e-4, - max_seq_length=process_args.sequence_length, - max_predictions_per_seq=process_args.max_predictions_per_seq, - train_batch_size=process_args.max_batch_size, - gradient_accumulation_steps=1, - fp16=process_args.enable_mixed_precision, - gelu_recompute=process_args.gelu_recompute, - attn_dropout_recompute=process_args.attn_dropout_recompute, - transformer_layer_recompute=process_args.transformer_layer_recompute, - allreduce_post_accumulation=True, - # TODO: remove - force_num_hidden_layers=2, - ) - do_pretrain(args) - - def test_pretrain_convergence(self): - args = PretrainArguments( - output_dir=self.output_dir, - bert_model=self.bert_model, - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=self.max_steps, - learning_rate=self.learning_rate, - max_seq_length=self.max_seq_length, - max_predictions_per_seq=self.max_predictions_per_seq, - train_batch_size=self.train_batch_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - input_dir=self.input_dir, - fp16=self.fp16, - allreduce_post_accumulation=self.allreduce_post_accumulation, - force_num_hidden_layers=self.force_num_hidden_layers, - tensorboard_dir=generate_tensorboard_logdir("/bert_data/hf_data/test_out/"), - ) - final_loss = do_pretrain(args) - return final_loss - - def test_pretrain_zero(self): - assert self.world_size > 0, "ZeRO test requires a distributed run." - setup_torch_distributed(self.world_rank, self.world_size) - per_gpu_batch_size = 32 - optimization_batch_size = per_gpu_batch_size * self.world_size # set to disable grad accumulation - - self.train_batch_size = optimization_batch_size - self.gradient_accumulation_steps = 1 - self.deepspeed_zero_stage = 1 - self.force_num_hidden_layers = 2 - self.max_seq_length = 32 - self.output_dir = "./bert_pretrain_ckpt" - if self.world_rank == 0: - if os.path.isdir(self.output_dir): - shutil.rmtree(self.output_dir) - os.makedirs(self.output_dir, exist_ok=True) - - torch.distributed.barrier() - - assert os.path.exists(self.output_dir) - - # run a few optimization steps - self.max_steps = 200 - args = PretrainArguments( - output_dir=self.output_dir, - bert_model=self.bert_model, - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=self.max_steps, - learning_rate=self.learning_rate, - max_seq_length=self.max_seq_length, - max_predictions_per_seq=self.max_predictions_per_seq, - train_batch_size=self.train_batch_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - input_dir=self.input_dir, - fp16=self.fp16, - allreduce_post_accumulation=self.allreduce_post_accumulation, - force_num_hidden_layers=self.force_num_hidden_layers, - deepspeed_zero_stage=self.deepspeed_zero_stage, - save_checkpoint=True, - ) - do_pretrain(args) - - # ensure all workers reach this point before loading the checkpointed state - torch.distributed.barrier() - - # on rank 0, load the trained state - if args.world_rank == 0: - checkpoint_files = glob.glob(os.path.join(self.output_dir, "checkpoint*.ortcp")) - args.init_state_dict = aggregate_checkpoints(checkpoint_files, pytorch_format=True) - - torch.distributed.barrier() - - # run a single step to get the loss, on rank 0 should be lesser than starting loss - args.save_checkpoint = False - args.max_steps = 1 - args.deepspeed_zero_stage = 0 - final_loss = do_pretrain(args) - return final_loss - - -if __name__ == "__main__": - import sys - - logger.warning("sys.argv: %s", sys.argv) - # usage: - # data parallel training - # mpirun -n 4 python orttraining_run_bert_pretrain.py - # - # single gpu: - # python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput - # [batch size test arguments] - # python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_convergence - # - # pytorch.distributed.launch will not work because ort backend requires MPI to broadcast ncclUniqueId - # calling unpublished get_mpi_context_xxx to get rank/size numbers. - try: - # In case ORT is not built with MPI/NCCL, there are no get_mpi_context_xxx internal apis. - from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size - - has_get_mpi_context_internal_api = True - except ImportError: - has_get_mpi_context_internal_api = False - pass - if has_get_mpi_context_internal_api and get_mpi_context_world_size() > 1: - world_size = get_mpi_context_world_size() - print("get_mpi_context_world_size(): ", world_size) - local_rank = get_mpi_context_local_rank() - - if local_rank == 0: - print("================================================================> os.getpid() = ", os.getpid()) - - test = ORTBertPretrainTest() - test.setUp() - test.local_rank = local_rank - test.world_rank = local_rank - test.world_size = world_size - - if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_zero": - logger.info("running ORTBertPretrainTest.test_pretrain_zero()...") - final_loss = test.test_pretrain_zero() - logger.info("ORTBertPretrainTest.test_pretrain_zero() rank = %i final loss = %f", local_rank, final_loss) - if local_rank == 0: - test.assertLess(final_loss, 10.2) - else: - test.assertGreater(final_loss, 11.0) - logger.info("ORTBertPretrainTest.test_pretrain_zero() passed") - elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence": - logger.info("running ORTBertPretrainTest.test_pretrain_convergence()...") - test.max_steps = 200 - test.force_num_hidden_layers = 8 - final_loss = test.test_pretrain_convergence() - logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss) - test.assertLess(final_loss, 8.5) - logger.info("ORTBertPretrainTest.test_pretrain_convergence() passed") - else: - # https://microsoft.sharepoint.com/teams/ONNX2/_layouts/15/Doc.aspx?sourcedoc={170774be-e1c6-4f8b-a3ae-984f211fe410}&action=edit&wd=target%28ONNX%20Training.one%7C8176133b-c7cb-4ef2-aa9d-3fdad5344c40%2FGitHub%20Master%20Merge%20Schedule%7Cb67f0db1-e3a0-4add-80a6-621d67fd8107%2F%29 - # to make equivalent args for cpp convergence test - test.max_seq_length = 128 - test.max_predictions_per_seq = 20 - test.gradient_accumulation_steps = 16 - - # cpp_batch_size (=64) * grad_acc * world_size - test.train_batch_size = 64 * test.gradient_accumulation_steps * test.world_size - test.max_steps = 300000 - - test.force_num_hidden_layers = None - - # already using Adam (e.g. AdamConfig) - test.learning_rate = 5e-4 - test.warmup_proportion = 0.1 - - final_loss = test.test_pretrain_convergence() - logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss) - else: - # unittest does not accept user defined arguments - # we need to run this script with user defined arguments - if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_throughput": - run_test_pretrain_throughput, run_test_pretrain_convergence = True, False - sys.argv.remove("ORTBertPretrainTest.test_pretrain_throughput") - elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence": - run_test_pretrain_throughput, run_test_pretrain_convergence = False, True - sys.argv.remove("ORTBertPretrainTest.test_pretrain_convergence") - else: - run_test_pretrain_throughput, run_test_pretrain_convergence = True, True - process_args = parse_arguments() - test = ORTBertPretrainTest() - test.setUp() - - if run_test_pretrain_throughput: - logger.info("running single GPU ORTBertPretrainTest.test_pretrain_throughput()...") - test.test_pretrain_throughput(process_args) - logger.info("single GPU ORTBertPretrainTest.test_pretrain_throughput() passed") - - # unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py b/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py deleted file mode 100644 index 3e2d1a7154bfd..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py +++ /dev/null @@ -1,67 +0,0 @@ -import collections -import subprocess -import sys - -Config = collections.namedtuple( - "Config", - [ - "enable_mixed_precision", - "sequence_length", - "max_batch_size", - "max_predictions_per_seq", - "gelu_recompute", - "attn_dropout_recompute", - "transformer_layer_recompute", - ], -) - -configs = [ - Config(True, 128, 46, 20, False, False, False), - Config(True, 512, 8, 80, False, False, False), - Config(False, 128, 26, 20, False, False, False), - Config(False, 512, 4, 80, False, False, False), - Config(True, 128, 50, 20, True, False, False), - Config(True, 128, 50, 20, False, True, False), - Config(True, 128, 76, 20, False, False, True), - Config(True, 512, 8, 80, True, False, False), - Config(True, 512, 9, 80, False, True, False), - Config(True, 512, 15, 80, False, False, True), -] - - -def run_with_config(config): - print( - "##### testing name - {}-{} #####".format( - "fp16" if config.enable_mixed_precision else "fp32", config.sequence_length - ) - ) - print("gelu_recompute: ", config.gelu_recompute) - print("attn_dropout_recompute: ", config.attn_dropout_recompute) - print("transformer_layer_recompute: ", config.transformer_layer_recompute) - - cmds = [ - sys.executable, - "orttraining_run_bert_pretrain.py", - "ORTBertPretrainTest.test_pretrain_throughput", - "--sequence_length", - str(config.sequence_length), - "--max_batch_size", - str(config.max_batch_size), - "--max_predictions_per_seq", - str(config.max_predictions_per_seq), - ] - if config.enable_mixed_precision: - cmds.append("--enable_mixed_precision") - if config.gelu_recompute: - cmds.append("--gelu_recompute") - if config.attn_dropout_recompute: - cmds.append("--attn_dropout_recompute") - if config.transformer_layer_recompute: - cmds.append("--transformer_layer_recompute") - - # access to azure storage shared disk is much slower so we need a longer timeout. - subprocess.run(cmds, timeout=1200).check_returncode() # noqa: PLW1510 - - -for config in configs: - run_with_config(config) diff --git a/orttraining/orttraining/test/python/orttraining_run_glue.py b/orttraining/orttraining/test/python/orttraining_run_glue.py deleted file mode 100644 index 794e2f8cc7240..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_glue.py +++ /dev/null @@ -1,323 +0,0 @@ -# adapted from run_glue.py of huggingface transformers - -import dataclasses # noqa: F401 -import logging -import os -import unittest -from dataclasses import dataclass, field -from typing import Dict, Optional - -import numpy as np -from numpy.testing import assert_allclose -from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - AutoTokenizer, - EvalPrediction, - GlueDataset, - GlueDataTrainingArguments, - TrainingArguments, - glue_compute_metrics, - glue_output_modes, - glue_tasks_num_labels, - set_seed, -) - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - -try: - from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size - - has_get_mpi_context_internal_api = True -except ImportError: - has_get_mpi_context_internal_api = False - pass - - -import torch # noqa: F401 -from orttraining_transformer_trainer import ORTTransformerTrainer - -logger = logging.getLogger(__name__) - - -def verify_old_and_new_api_are_equal(results_per_api): - new_api_results = results_per_api[True] - old_api_results = results_per_api[False] - for key in new_api_results: - assert_allclose(new_api_results[key], old_api_results[key]) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"}) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} - ) - - -class ORTGlueTest(unittest.TestCase): - def setUp(self): - # configurations not to be changed accoss tests - self.max_seq_length = 128 - self.train_batch_size = 8 - self.learning_rate = 2e-5 - self.num_train_epochs = 3.0 - self.local_rank = -1 - self.world_size = 1 - self.overwrite_output_dir = True - self.gradient_accumulation_steps = 1 - self.data_dir = "/bert_data/hf_data/glue_data/" - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "glue_test_output/") - self.cache_dir = "/tmp/glue/" - self.logging_steps = 10 - - def test_roberta_with_mrpc(self): - expected_acc = 0.85 - expected_f1 = 0.88 - expected_loss = 0.35 - results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=False) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_roberta_fp16_with_mrpc(self): - expected_acc = 0.87 - expected_f1 = 0.90 - expected_loss = 0.33 - - results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=True) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_bert_with_mrpc(self): - if self.local_rank == -1: - expected_acc = 0.83 - expected_f1 = 0.88 - expected_loss = 0.44 - elif self.local_rank == 0: - expected_acc = 0.81 - expected_f1 = 0.86 - expected_loss = 0.44 - - results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False) - - if self.local_rank in [-1, 0]: - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_bert_fp16_with_mrpc(self): - expected_acc = 0.84 - expected_f1 = 0.88 - expected_loss = 0.46 - - results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=True) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def model_to_desc(self, model_name, model): - if model_name.startswith("bert") or model_name.startswith("xlnet"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "labels", - [ - "batch", - ], - ), - ], - "outputs": [("loss", [], True), ("logits", ["batch", 2])], - } - elif model_name.startswith("roberta"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "labels", - [ - "batch", - ], - ), - ], - "outputs": [("loss", [], True), ("logits", ["batch", 2])], - } - else: - raise RuntimeError(f"unsupported base model name {model_name}.") - - return model_desc - - def run_glue(self, model_name, task_name, fp16): - model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) - data_args = GlueDataTrainingArguments( - task_name=task_name, data_dir=os.path.join(self.data_dir, task_name), max_seq_length=self.max_seq_length - ) - - training_args = TrainingArguments( - output_dir=os.path.join(self.output_dir, task_name), - do_train=True, - do_eval=True, - per_gpu_train_batch_size=self.train_batch_size, - learning_rate=self.learning_rate, - num_train_epochs=self.num_train_epochs, - local_rank=self.local_rank, - overwrite_output_dir=self.overwrite_output_dir, - gradient_accumulation_steps=self.gradient_accumulation_steps, - fp16=fp16, - logging_steps=self.logging_steps, - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, - ) - logger.warning( - "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - training_args.local_rank, - training_args.device, - training_args.n_gpu, - bool(training_args.local_rank != -1), - training_args.fp16, - ) - logger.info("Training/evaluation parameters %s", training_args) - - set_seed(training_args.seed) - onnxruntime.set_seed(training_args.seed) - - try: - num_labels = glue_tasks_num_labels[data_args.task_name] - output_mode = glue_output_modes[data_args.task_name] - except KeyError: - raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904 - - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task=data_args.task_name, - cache_dir=model_args.cache_dir, - ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - ) - - model = AutoModelForSequenceClassification.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - ) - - train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None - - eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None - - def compute_metrics(p: EvalPrediction) -> Dict: - if output_mode == "classification": - preds = np.argmax(p.predictions, axis=1) - elif output_mode == "regression": - preds = np.squeeze(p.predictions) - return glue_compute_metrics(data_args.task_name, preds, p.label_ids) - - model_desc = self.model_to_desc(model_name, model) - # Initialize the ORTTrainer within ORTTransformerTrainer - trainer = ORTTransformerTrainer( - model=model, - model_desc=model_desc, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - compute_metrics=compute_metrics, - world_size=self.world_size, - ) - - # Training - if training_args.do_train: - trainer.train() - trainer.save_model() - - # Evaluation - results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: - logger.info("*** Evaluate ***") - - result = trainer.evaluate() - - logger.info(f"***** Eval results {data_args.task_name} *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - - results.update(result) - - return results - - -if __name__ == "__main__": - if has_get_mpi_context_internal_api: - local_rank = get_mpi_context_local_rank() - world_size = get_mpi_context_world_size() - else: - local_rank = -1 - world_size = 1 - - if world_size > 1: - # mpi launch - logger.warning("mpirun launch, local_rank / world_size: %s : % s", local_rank, world_size) - - # TrainingArguments._setup_devices will call torch.distributed.init_process_group(backend="nccl") - # pytorch expects following environment settings (which would be set if launched with torch.distributed.launch). - - os.environ["RANK"] = str(local_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "29500" - - from onnxruntime.capi._pybind_state import set_cuda_device_id - - set_cuda_device_id(local_rank) - - test = ORTGlueTest() - test.setUp() - test.local_rank = local_rank - test.world_size = world_size - test.test_bert_with_mrpc() - else: - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py b/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py deleted file mode 100644 index 92db204593bcd..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py +++ /dev/null @@ -1,281 +0,0 @@ -# adapted from run_multiple_choice.py of huggingface transformers -# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/run_multiple_choice.py - -import dataclasses # noqa: F401 -import logging -import os -import unittest -from dataclasses import dataclass, field -from typing import Dict, Optional - -import numpy as np -import torch # noqa: F401 -from numpy.testing import assert_allclose # noqa: F401 -from orttraining_run_glue import verify_old_and_new_api_are_equal # noqa: F401 -from orttraining_transformer_trainer import ORTTransformerTrainer -from transformers import HfArgumentParser # noqa: F401 -from transformers import Trainer # noqa: F401 -from transformers import ( - AutoConfig, - AutoModelForMultipleChoice, - AutoTokenizer, - EvalPrediction, - TrainingArguments, - set_seed, -) -from utils_multiple_choice import MultipleChoiceDataset, Split, SwagProcessor - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - -logger = logging.getLogger(__name__) - - -def simple_accuracy(preds, labels): - return (preds == labels).mean() - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"}) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} - ) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ - - task_name: str = field(metadata={"help": "The name of the task to train on."}) - data_dir: str = field(metadata={"help": "Should contain the data files for the task."}) - max_seq_length: int = field( - default=128, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - }, - ) - overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}) - - -class ORTMultipleChoiceTest(unittest.TestCase): - def setUp(self): - # configurations not to be changed accoss tests - self.max_seq_length = 80 - self.train_batch_size = 16 - self.eval_batch_size = 2 - self.learning_rate = 2e-5 - self.num_train_epochs = 1.0 - self.local_rank = -1 - self.overwrite_output_dir = True - self.gradient_accumulation_steps = 8 - self.data_dir = "/bert_data/hf_data/swag/swagaf/data" - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "multiple_choice_test_output/") - self.cache_dir = "/tmp/multiple_choice/" - self.logging_steps = 10 - self.rtol = 2e-01 - - def test_bert_with_swag(self): - expected_acc = 0.75 - expected_loss = 0.64 - - results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=False) - assert results["acc"] >= expected_acc - assert results["loss"] <= expected_loss - - def test_bert_fp16_with_swag(self): - # larger batch can be handled with mixed precision - self.train_batch_size = 32 - - expected_acc = 0.73 - expected_loss = 0.68 - - results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=True) - assert results["acc"] >= expected_acc - assert results["loss"] <= expected_loss - - def run_multiple_choice(self, model_name, task_name, fp16): - model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) - data_args = DataTrainingArguments( - task_name=task_name, data_dir=self.data_dir, max_seq_length=self.max_seq_length - ) - - training_args = TrainingArguments( - output_dir=os.path.join(self.output_dir, task_name), - do_train=True, - do_eval=True, - per_gpu_train_batch_size=self.train_batch_size, - per_gpu_eval_batch_size=self.eval_batch_size, - learning_rate=self.learning_rate, - num_train_epochs=self.num_train_epochs, - local_rank=self.local_rank, - overwrite_output_dir=self.overwrite_output_dir, - gradient_accumulation_steps=self.gradient_accumulation_steps, - fp16=fp16, - logging_steps=self.logging_steps, - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, - ) - logger.warning( - "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - training_args.local_rank, - training_args.device, - training_args.n_gpu, - bool(training_args.local_rank != -1), - training_args.fp16, - ) - logger.info("Training/evaluation parameters %s", training_args) - - set_seed(training_args.seed) - onnxruntime.set_seed(training_args.seed) - - try: - processor = SwagProcessor() - label_list = processor.get_labels() - num_labels = len(label_list) - except KeyError: - raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904 - - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task=data_args.task_name, - cache_dir=model_args.cache_dir, - ) - - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - ) - - model = AutoModelForMultipleChoice.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - ) - - # Get datasets - train_dataset = ( - MultipleChoiceDataset( - data_dir=data_args.data_dir, - tokenizer=tokenizer, - task=data_args.task_name, - processor=processor, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.train, - ) - if training_args.do_train - else None - ) - eval_dataset = ( - MultipleChoiceDataset( - data_dir=data_args.data_dir, - tokenizer=tokenizer, - task=data_args.task_name, - processor=processor, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.dev, - ) - if training_args.do_eval - else None - ) - - def compute_metrics(p: EvalPrediction) -> Dict: - preds = np.argmax(p.predictions, axis=1) - return {"acc": simple_accuracy(preds, p.label_ids)} - - if model_name.startswith("bert"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "labels", - ["batch", num_labels], - ), - ], - "outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])], - } - else: - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "labels", - ["batch", num_labels], - ), - ], - "outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])], - } - - # Initialize the ORTTrainer within ORTTransformerTrainer - trainer = ORTTransformerTrainer( - model=model, - model_desc=model_desc, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - compute_metrics=compute_metrics, - ) - - # Training - if training_args.do_train: - trainer.train() - trainer.save_model() - - # Evaluation - results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: - logger.info("*** Evaluate ***") - - result = trainer.evaluate() - - logger.info(f"***** Eval results {data_args.task_name} *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - - results.update(result) - - return results - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py b/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py deleted file mode 100644 index 71e6bb8e4d2f2..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py +++ /dev/null @@ -1,6 +0,0 @@ -from orttraining_test_layer_norm_transform import layer_norm_transform # noqa: F401 -from orttraining_test_model_transform import add_expand_shape, add_name, fix_transpose # noqa: F401 - - -def postprocess_model(model): - add_name(model) diff --git a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py b/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py deleted file mode 100644 index 21372caaf6779..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# orttraining_test_checkpoint_storage.py - -import os -import pickle -import shutil - -import numpy as np -import pytest -import torch - -from onnxruntime.training import _checkpoint_storage - -# Helper functions - - -def _equals(a, b): - """Checks recursively if two dictionaries are equal""" - if isinstance(a, dict): - return all(not (key not in b or not _equals(a[key], b[key])) for key in a) - else: - if isinstance(a, bytes): - a = a.decode() - if isinstance(b, bytes): - b = b.decode() - are_equal = a == b - return are_equal if isinstance(are_equal, bool) else are_equal.all() - - return False - - -def _numpy_types(obj_value): - """Return a bool indicating whether or not the input obj_value is a numpy type object - - Recursively checks if the obj_value (could be a dictionary) is a numpy type object. - Exceptions are str and bytes. - - Returns true if object is numpy type, str, or bytes - False if any other type - """ - if not isinstance(obj_value, dict): - return isinstance(obj_value, (str, bytes)) or type(obj_value).__module__ == np.__name__ - - return all(_numpy_types(value) for _, value in obj_value.items()) - - -def _get_dict(separated_key): - """Create dummy dictionary with different datatypes - - Returns the tuple of the entire dummy dictionary created, key argument as a dictionary for _checkpoint_storage.load - function and the value for that key in the original dictionary - - For example the complete dictionary is represented by: - { - 'int1':1, - 'int2': 2, - 'int_list': [1,2,3,5,6], - 'dict1': { - 'np_array': np.arange(100), - 'dict2': {'int3': 3, 'int4': 4}, - 'str1': "onnxruntime" - }, - 'bool1': bool(True), - 'int5': 5, - 'float1': 2.345, - 'np_array_float': np.array([1.234, 2.345, 3.456]), - 'np_array_float_3_dim': np.array([[[1,2],[3,4]], [[5,6],[7,8]]]) - } - - if the input key is ['dict1', 'str1'], then the key argument returned is 'dict1/str1' - and the value corresponding to that is "onnxruntime" - - so, for the above example, the returned tuple is: - (original_dict, {'key': 'dict1/str1', "onnxruntime") - """ - test_dict = { - "int1": 1, - "int2": 2, - "int_list": [1, 2, 3, 5, 6], - "dict1": {"np_array": np.arange(100), "dict2": {"int3": 3, "int4": 4}, "str1": "onnxruntime"}, - "bool1": True, - "int5": 5, - "float1": 2.345, - "np_array_float": np.array([1.234, 2.345, 3.456]), - "np_array_float_3_dim": np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), - } - key = "" - expected_val = test_dict - for single_key in separated_key: - key += single_key + "/" - expected_val = expected_val[single_key] - return test_dict, {"key": key} if len(separated_key) > 0 else dict(), expected_val - - -class _CustomClass: - """Custom object that encpsulates dummy values for loss, epoch and train_step""" - - def __init__(self): - self._loss = 1.23 - self._epoch = 12000 - self._train_step = 25 - - def __eq__(self, other): - if isinstance(other, _CustomClass): - return self._loss == other._loss and self._epoch == other._epoch and self._train_step == other._train_step - - -# Test fixtures - - -@pytest.yield_fixture(scope="function") -def checkpoint_storage_test_setup(): - checkpoint_dir = os.path.abspath("checkpoint_dir/") - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir, exist_ok=True) - pytest.checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ortcp") - yield "checkpoint_storage_test_setup" - shutil.rmtree(checkpoint_dir) - - -@pytest.yield_fixture(scope="function") -def checkpoint_storage_test_parameterized_setup(request, checkpoint_storage_test_setup): - yield request.param - - -# Tests - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - _get_dict([]), - _get_dict(["int1"]), - _get_dict(["dict1"]), - _get_dict(["dict1", "dict2"]), - _get_dict(["dict1", "dict2", "int4"]), - _get_dict(["dict1", "str1"]), - _get_dict(["bool1"]), - _get_dict(["float1"]), - _get_dict(["np_array_float"]), - ], - indirect=True, -) -def test_checkpoint_storage_saved_dict_matches_loaded(checkpoint_storage_test_parameterized_setup): - to_save = checkpoint_storage_test_parameterized_setup[0] - key_arg = checkpoint_storage_test_parameterized_setup[1] - expected = checkpoint_storage_test_parameterized_setup[2] - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - loaded = _checkpoint_storage.load(pytest.checkpoint_path, **key_arg) - assert _equals(loaded, expected) - assert _numpy_types(loaded) - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [{"int_set": {1, 2, 3, 4, 5}}, {"str_set": {"one", "two"}}, [1, 2, 3], 2.352], - indirect=True, -) -def test_checkpoint_storage_saving_non_supported_types_fails(checkpoint_storage_test_parameterized_setup): - to_save = checkpoint_storage_test_parameterized_setup - with pytest.raises(Exception): # noqa: B017 - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - ({"int64_tensor": torch.tensor(np.arange(100))}, "int64_tensor", torch.int64, np.int64), - ({"int32_tensor": torch.tensor(np.arange(100), dtype=torch.int32)}, "int32_tensor", torch.int32, np.int32), - ({"int16_tensor": torch.tensor(np.arange(100), dtype=torch.int16)}, "int16_tensor", torch.int16, np.int16), - ({"int8_tensor": torch.tensor(np.arange(100), dtype=torch.int8)}, "int8_tensor", torch.int8, np.int8), - ({"float64_tensor": torch.tensor(np.array([1.0, 2.0]))}, "float64_tensor", torch.float64, np.float64), - ( - {"float32_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32)}, - "float32_tensor", - torch.float32, - np.float32, - ), - ( - {"float16_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float16)}, - "float16_tensor", - torch.float16, - np.float16, - ), - ], - indirect=True, -) -def test_checkpoint_storage_saving_tensor_datatype(checkpoint_storage_test_parameterized_setup): - tensor_dict = checkpoint_storage_test_parameterized_setup[0] - tensor_name = checkpoint_storage_test_parameterized_setup[1] - tensor_dtype = checkpoint_storage_test_parameterized_setup[2] - np_dtype = checkpoint_storage_test_parameterized_setup[3] - - _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert isinstance(loaded[tensor_name], np.ndarray) - assert tensor_dict[tensor_name].dtype == tensor_dtype - assert loaded[tensor_name].dtype == np_dtype - assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all() - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - ({"two_dim": torch.ones([2, 4], dtype=torch.float64)}, "two_dim"), - ({"three_dim": torch.ones([2, 4, 6], dtype=torch.float64)}, "three_dim"), - ({"four_dim": torch.ones([2, 4, 6, 8], dtype=torch.float64)}, "four_dim"), - ], - indirect=True, -) -def test_checkpoint_storage_saving_multiple_dimension_tensors(checkpoint_storage_test_parameterized_setup): - tensor_dict = checkpoint_storage_test_parameterized_setup[0] - tensor_name = checkpoint_storage_test_parameterized_setup[1] - - _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert isinstance(loaded[tensor_name], np.ndarray) - assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all() - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", [{}, {"a": {}}, {"a": {"b": {}}}], indirect=True -) -def test_checkpoint_storage_saving_and_loading_empty_dictionaries_succeeds(checkpoint_storage_test_parameterized_setup): - saved = checkpoint_storage_test_parameterized_setup - _checkpoint_storage.save(saved, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert _equals(saved, loaded) - - -def test_checkpoint_storage_load_file_that_does_not_exist_fails(checkpoint_storage_test_setup): - with pytest.raises(Exception): # noqa: B017 - _checkpoint_storage.load(pytest.checkpoint_path) - - -def test_checkpoint_storage_for_custom_user_dict_succeeds(checkpoint_storage_test_setup): - custom_class = _CustomClass() - user_dict = {"tensor1": torch.tensor(np.arange(100), dtype=torch.float32), "custom_class": custom_class} - - pickled_bytes = pickle.dumps(user_dict).hex() - to_save = {"a": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32), "user_dict": pickled_bytes} - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - - loaded_dict = _checkpoint_storage.load(pytest.checkpoint_path) - assert (loaded_dict["a"] == to_save["a"].numpy()).all() - try: # noqa: SIM105 - loaded_dict["user_dict"] = loaded_dict["user_dict"].decode() - except AttributeError: - pass - loaded_obj = pickle.loads(bytes.fromhex(loaded_dict["user_dict"])) - - assert torch.all(loaded_obj["tensor1"].eq(user_dict["tensor1"])) - assert loaded_obj["custom_class"] == custom_class diff --git a/orttraining/orttraining/test/python/orttraining_test_data_loader.py b/orttraining/orttraining/test/python/orttraining_test_data_loader.py index aa15b44ae0d66..0009d2d3d7e1b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_data_loader.py +++ b/orttraining/orttraining/test/python/orttraining_test_data_loader.py @@ -4,8 +4,6 @@ import torch from torch.utils.data import DataLoader, Dataset -from onnxruntime.capi.ort_trainer import generate_sample - global_rng = random.Random() @@ -41,6 +39,16 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() +def generate_sample(desc, device=None): + """Generate a sample based on the description""" + # symbolic dimensions are described with strings. set symbolic dimensions to be 1 + size = [s if isinstance(s, (int)) else 1 for s in desc.shape_] + if desc.num_classes_: + return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device) + else: + return torch.randn(size, dtype=desc.dtype_).to(device) + + class OrtTestDataset(Dataset): def __init__(self, input_desc, seq_len, dataset_len, device): import copy diff --git a/orttraining/orttraining/test/python/orttraining_test_debuggability.py b/orttraining/orttraining/test/python/orttraining_test_debuggability.py deleted file mode 100644 index 499f0ba7a1ff5..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_debuggability.py +++ /dev/null @@ -1,40 +0,0 @@ -import pytest -import torch -from _test_commons import _load_pytorch_transformer_model - -from onnxruntime import set_seed -from onnxruntime.training import optim, orttrainer - -############################################################################### -# Testing starts here ######################################################### -############################################################################### - - -@pytest.mark.parametrize( - "seed, device", - [ - (24, "cuda"), - ], -) -def testORTTransformerModelExport(seed, device): - # Common setup - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": { - "check_model_export": True, - }, - "device": { - "id": device, - }, - } - ) - - # Setup for the first ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device) - first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - _ = first_trainer.train_step(data, targets) - assert first_trainer._onnx_model is not None diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index 88d9c00984d3e..2a7012787be6e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -19,6 +19,7 @@ class TestTorchDynamoOrt(unittest.TestCase): def setUp(self): # Make computation deterministic. torch.manual_seed(42) + print(f"TestTorchDynamoOrt uses PyTorch version {torch.__version__}") def test_elementwise_model(self): torch._dynamo.reset() diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis.py index 506aafbe9f618..a3e666dd404f2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis.py @@ -27,7 +27,7 @@ def run_training_apis_python_api_tests(cwd, log): log.debug("Running: ort training api tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_python_bindings.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_py_bindings.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() @@ -37,7 +37,7 @@ def run_onnxblock_tests(cwd, log): log.debug("Running: onnxblock tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnxblock.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_onnxblock.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() diff --git a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_onnxblock.py rename to orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py similarity index 99% rename from orttraining/orttraining/test/python/orttraining_test_python_bindings.py rename to orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py index d5c37b3e36ee7..34d8c24ccfab4 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py @@ -11,7 +11,7 @@ import onnx import pytest import torch -from orttraining_test_onnxblock import _get_models +from orttraining_test_ort_apis_onnxblock import _get_models import onnxruntime.training.onnxblock as onnxblock from onnxruntime import OrtValue, SessionOptions diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 13024b81f4b3c..ad0e5d8beba3d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5761,6 +5761,7 @@ def run_step(model, input, positions): ("MatMul", 1), ("Dropout", 0), ("LayerNormalization", 0), + ("LayerNormalization", 1), ("Cast", 0), ("BiasGelu", 0), ("Gelu", 0), @@ -5773,12 +5774,18 @@ def test_ops_for_padding_elimination(test_cases): test_op = test_cases[0] case = test_cases[1] + vocab_size, hidden_size = 50265, 768 + batch_size, max_seq_length = 8, 128 + class ToyModel(torch.nn.Module): def __init__(self, vocab_size, hidden_size, pad_token_id): super().__init__() self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) if test_op == "LayerNormalization": - self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05) + if case == 0: + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05) + else: + self.LayerNorm = nn.LayerNorm([max_seq_length, hidden_size], eps=1e-05) self.hidden_size = hidden_size # test test_elementwise op for padding elimination @@ -5889,8 +5896,6 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): batched_inputs.append(torch.cat((input_id, padding))) return torch.stack(batched_inputs) - vocab_size, hidden_size = 50265, 768 - batch_size, max_seq_length = 8, 128 device = "cuda" model = ORTModule(ToyModel(vocab_size, hidden_size, 1).to(device)) x = generate_inputs(batch_size, max_seq_length, vocab_size) @@ -5908,7 +5913,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3 else: assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2 - gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten") + recover_pad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten") def find_input_node_type(model, arg): result = [] @@ -5917,14 +5922,14 @@ def find_input_node_type(model, arg): result.append(node) return result[0].op_type if len(result) == 1 else None - gathergrad_input_optypes = [find_input_node_type(training_model, arg) for arg in gathergrad_node.input] + recover_pad_input_optypes = [find_input_node_type(training_model, arg) for arg in recover_pad_node.input] if test_op == "Add" or test_op == "Mul" or test_op == "Sub": - assert test_op in gathergrad_input_optypes + assert test_op in recover_pad_input_optypes else: if case == 0: - assert test_op in gathergrad_input_optypes + assert test_op in recover_pad_input_optypes else: - assert "ATen" in gathergrad_input_optypes + assert "ATen" in recover_pad_input_optypes del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] diff --git a/orttraining/orttraining/test/python/orttraining_test_hooks.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_hooks.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_hooks.py rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_hooks.py diff --git a/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py index e66582bda9ed1..0c381d70ca4c1 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py @@ -543,6 +543,8 @@ def _gen_inputs(dtype): ([123, 4, 5, 6], [2], False), ([16, 8, 16, 8], [1, 3], True), ([16, 8, 16, 8], [0, 2], False), + ([16, 8, 16, 8], [0, 1, 2, 3], True), + ([16, 1, 16, 8], [0, 1, 2, 3], False), ], ) def test_reduce_op(op_type, onnx_dtype, input_shape_and_reduce_info): @@ -871,3 +873,34 @@ def _gen_inputs(dtype): return [torch.rand(m_n_k[0], m_n_k[2], dtype=dtype, device=DEVICE, requires_grad=True)] _run_tunable_op_test(NeuralNetGemm, dtype, _gen_inputs, "GemmTunableOp", 2) + + +def test_user_config(): + n, d, h, w = 8, 768, 12, 64 + dtype = torch.float32 + + class NeuralNetElementwise(torch.nn.Module): + def forward(self, input1, input2, input3, input4): + return input1 + input2 - input3 * input4 + + def _gen_inputs(dtype): + return [ + torch.rand(n, d, h, w, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(w, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(d, 1, 1, dtype=dtype, device=DEVICE, requires_grad=True), + torch.rand(n, 1, h, w, dtype=dtype, device=DEVICE, requires_grad=True), + ] + + user_config = ( + '{"ops": {"Add": {"versions": [13, 14]}, "Mul": {"versions": [13, 14]}}, ' + '"initializer": "scalar", "min_nodes": 2}' + ) + with open("user_config.json", "w", encoding="UTF-8") as f: + f.write(user_config) + os.environ["ORTMODULE_TRITON_CONFIG_FILE"] = "./user_config.json" + + # Mul is not supported, the graph is splited to 2 subgraphs with single Op, which will not be fused to TritonOp. + _run_module_test(NeuralNetElementwise, dtype, _gen_inputs, 0) + + del os.environ["ORTMODULE_TRITON_CONFIG_FILE"] + os.remove(os.path.join(os.getcwd(), "user_config.json")) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py deleted file mode 100644 index 45b87b32f7d64..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py +++ /dev/null @@ -1,1283 +0,0 @@ -import copy # noqa: F401 -import inspect # noqa: F401 -import math # noqa: F401 -import os -from functools import partial - -import _test_commons -import _test_helpers -import onnx -import pytest -import torch -from numpy.testing import assert_allclose - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription -from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler -from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import amp, optim, orttrainer - -############################################################################### -# Helper functions ############################################################ -############################################################################### - - -def generate_random_input_from_model_desc(desc, seed=1, device="cuda:0"): - """Generates a sample input for the BERT model using the model desc""" - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - dtype = torch.int64 - vocab_size = 30528 - num_classes = [vocab_size, 2, 2, vocab_size, 2] - dims = {"batch_size": 16, "seq_len": 1} - sample_input = [] - for index, input in enumerate(desc["inputs"]): - size = [] - for s in input[1]: - if isinstance(s, (int)): - size.append(s) - else: - size.append(dims[s] if s in dims else 1) - sample_input.append(torch.randint(0, num_classes[index], tuple(size), dtype=dtype).to(device)) - return sample_input - - -# EXPERIMENTAL HELPER FUNCTIONS - - -def bert_model_description(dynamic_shape=True): - """Creates the model description dictionary with static dimensions""" - - if dynamic_shape: - model_desc = { - "inputs": [ - ("input_ids", ["batch_size", "seq_len"]), - ( - "segment_ids", - ["batch_size", "seq_len"], - ), - ( - "input_mask", - ["batch_size", "seq_len"], - ), - ( - "masked_lm_labels", - ["batch_size", "seq_len"], - ), - ( - "next_sentence_labels", - [ - "batch_size", - ], - ), - ], - "outputs": [("loss", [], True)], - } - else: - batch_size = 16 - seq_len = 1 - model_desc = { - "inputs": [ - ("input_ids", [batch_size, seq_len]), - ( - "segment_ids", - [batch_size, seq_len], - ), - ( - "input_mask", - [batch_size, seq_len], - ), - ( - "masked_lm_labels", - [batch_size, seq_len], - ), - ( - "next_sentence_labels", - [ - batch_size, - ], - ), - ], - "outputs": [("loss", [], True)], - } - return model_desc - - -def optimizer_parameters(model): - """A method to assign different hyper parameters for different model parameter groups""" - - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay_param_group = [] - for initializer in model.graph.initializer: - if any(key in initializer.name for key in no_decay_keys): - no_decay_param_group.append(initializer.name) - params = [ - { - "params": no_decay_param_group, - "alpha": 0.9, - "beta": 0.999, - "lambda_coef": 0.0, - "epsilon": 1e-6, - "do_bias_correction": False, - } - ] - - return params - - -def load_bert_onnx_model(): - bert_onnx_model_path = os.path.join("testdata", "bert_toy_postprocessed.onnx") - model = onnx.load(bert_onnx_model_path) - return model - - -class CustomLossScaler(amp.LossScaler): - def __init__(self, loss_scale=float(1 << 16)): - super().__init__(loss_scale) - self._initial_loss_scale = loss_scale - self.loss_scale = loss_scale - - def reset(self): - self.loss_scale = self._initial_loss_scale - - def update(self, train_step_info): - self.loss_scale *= 0.9 - return self.loss_scale - - -# LEGACY HELPER FUNCTIONS - - -class LegacyCustomLossScaler: - def __init__(self, loss_scale=float(1 << 16)): - self._initial_loss_scale = loss_scale - self.loss_scale_ = loss_scale - - def reset(self): - self.loss_scale_ = self._initial_loss_scale - - def update_loss_scale(self, is_all_finite): - self.loss_scale_ *= 0.9 - - -def legacy_model_params(lr, device=torch.device("cuda", 0)): # noqa: B008 - legacy_model_desc = legacy_bert_model_description() - learning_rate_description = legacy_ort_trainer_learning_rate_description() - learning_rate = torch.tensor([lr]).to(device) - return (legacy_model_desc, learning_rate_description, learning_rate) - - -def legacy_ort_trainer_learning_rate_description(): - return Legacy_IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - - -def legacy_bert_model_description(): - input_ids_desc = Legacy_IODescription("input_ids", ["batch", "max_seq_len_in_batch"]) - segment_ids_desc = Legacy_IODescription("segment_ids", ["batch", "max_seq_len_in_batch"]) - input_mask_desc = Legacy_IODescription("input_mask", ["batch", "max_seq_len_in_batch"]) - masked_lm_labels_desc = Legacy_IODescription("masked_lm_labels", ["batch", "max_seq_len_in_batch"]) - next_sentence_labels_desc = Legacy_IODescription( - "next_sentence_labels", - [ - "batch", - ], - ) - loss_desc = Legacy_IODescription("loss", []) - - return Legacy_ModelDescription( - [input_ids_desc, segment_ids_desc, input_mask_desc, masked_lm_labels_desc, next_sentence_labels_desc], - [loss_desc], - ) - - -def legacy_optim_params_a(name): - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -def legacy_optim_params_b(name): - params = ["bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.weight"] - if name in params: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6, "do_bias_correction": False} - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -def legacy_optim_params_c(name): - params_group = optimizer_parameters(load_bert_onnx_model()) - if name in params_group[0]["params"]: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6, "do_bias_correction": False} - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -############################################################################### -# Testing starts here ######################################################### -############################################################################### - - -@pytest.mark.parametrize("dynamic_shape", [(True), (False)]) -def testToyBERTModelBasicTraining(dynamic_shape): - model_desc = bert_model_description(dynamic_shape) - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({}) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for _i in range(10): - sample_input = generate_random_input_from_model_desc(model_desc) - output = trainer.train_step(*sample_input) - assert output.shape == torch.Size([]) - - -@pytest.mark.parametrize( - "expected_losses", - [([11.041123, 10.986166, 11.101636, 11.013366, 11.03775, 11.041175, 10.957118, 11.069563, 11.040824, 11.16437])], -) -def testToyBERTDeterministicCheck(expected_losses): - # Common setup - train_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optimizer_parameters(model) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - experimental_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "initial_lr, lr_scheduler, expected_learning_rates, expected_losses", - [ - ( - 1.0, - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [ - 10.988012313842773, - 10.99213981628418, - 120.79301452636719, - 36.11647033691406, - 95.83200073242188, - 221.2766571044922, - 208.40316772460938, - 279.5332946777344, - 402.46380615234375, - 325.79254150390625, - ], - ), - ( - 0.5, - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], - [ - 10.988012313842773, - 10.99213981628418, - 52.69743347167969, - 19.741533279418945, - 83.88340759277344, - 126.39848327636719, - 91.53898620605469, - 63.62016296386719, - 102.21206665039062, - 180.1424560546875, - ], - ), - ( - 1.0, - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.9931806517013612, - 0.9397368756032445, - 0.8386407858128706, - 0.7008477123264848, - 0.5412896727361662, - 0.37725725642960045, - 0.22652592093878665, - 0.10542974530180327, - 0.02709137914968268, - ], - [ - 10.988012313842773, - 10.99213981628418, - 120.6441650390625, - 32.152557373046875, - 89.63705444335938, - 138.8782196044922, - 117.57748413085938, - 148.01927185058594, - 229.60403442382812, - 110.2930908203125, - ], - ), - ( - 1.0, - optim.lr_scheduler.LinearWarmupLRScheduler, - [ - 0.0, - 0.9473684210526315, - 0.8421052631578947, - 0.7368421052631579, - 0.631578947368421, - 0.5263157894736842, - 0.42105263157894735, - 0.3157894736842105, - 0.21052631578947367, - 0.10526315789473684, - ], - [ - 10.988012313842773, - 10.99213981628418, - 112.89633178710938, - 31.114538192749023, - 80.94029235839844, - 131.34490966796875, - 111.4329605102539, - 133.74252319335938, - 219.37344360351562, - 109.67041015625, - ], - ), - ( - 1.0, - optim.lr_scheduler.PolyWarmupLRScheduler, - [ - 0.0, - 0.9473684263157895, - 0.8421052789473684, - 0.7368421315789474, - 0.6315789842105263, - 0.5263158368421054, - 0.42105268947368424, - 0.31578954210526317, - 0.21052639473684212, - 0.10526324736842106, - ], - [ - 10.988012313842773, - 10.99213981628418, - 112.89633178710938, - 31.114538192749023, - 80.9402847290039, - 131.3447265625, - 111.43253326416016, - 133.7415008544922, - 219.37147521972656, - 109.66986083984375, - ], - ), - ], -) -def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates, expected_losses): - return # TODO: re-enable after nondeterminism on backend is fixed - # Common setup - device = "cuda" - total_steps = 10 - seed = 1 - warmup = 0.05 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Setup LR Schedulers - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig(lr=initial_lr) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "lr_scheduler": lr_scheduler, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - learning_rates = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0]) - - # Check output - _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=rtol) - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "loss_scaler, expected_losses", - [ - ( - None, - [ - 11.041126, - 10.986309, - 11.101673, - 11.013394, - 11.037781, - 11.041253, - 10.957072, - 11.069506, - 11.040807, - 11.164349, - ], - ), - ( - amp.DynamicLossScaler(), - [ - 11.041126, - 10.986309, - 11.101673, - 11.013394, - 11.037781, - 11.041253, - 10.957072, - 11.069506, - 11.040807, - 11.164349, - ], - ), - ( - CustomLossScaler(), - [ - 11.041126, - 10.986309, - 11.101645, - 11.013412, - 11.037757, - 11.041273, - 10.957077, - 11.069525, - 11.040765, - 11.164298, - ], - ), - ], -) -def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses): - # Common setup - total_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "gradient_accumulation_steps, expected_losses", - [ - ( - 1, - [ - 11.041123, - 10.986166, - 11.101636, - 11.013366, - 11.03775, - 11.041175, - 10.957118, - 11.069563, - 11.040824, - 11.16437, - ], - ), - ( - 4, - [ - 11.041123, - 10.982856, - 11.105512, - 11.006721, - 11.03358, - 11.05058, - 10.955864, - 11.059035, - 11.037753, - 11.162649, - ], - ), - ( - 7, - [ - 11.041123, - 10.982856, - 11.105512, - 11.006721, - 11.036314, - 11.055109, - 10.960751, - 11.05809, - 11.038856, - 11.159635, - ], - ), - ], -) -def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses): - # Common setup - total_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -def testToyBertCheckpointBasic(): - # Common setup - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({"debug": {"deterministic_compute": True}}) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - sd = trainer.state_dict() - - ## All initializers must be present in the state_dict - ## when the specified model for ORTTRainer is an ONNX model - for param in trainer._onnx_model.graph.initializer: - assert param.name in sd["model"]["full_precision"] - - ## Modify one of the state values and load into ORTTrainer - sd["model"]["full_precision"]["bert.encoder.layer.0.attention.output.LayerNorm.weight"] += 10 - trainer.load_state_dict(sd) - - ## Save a checkpoint - ckpt_dir = "testdata" - trainer.save_checkpoint(os.path.join(ckpt_dir, "bert_toy_save_test.ortcp")) - del trainer - del model - - # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer - model2 = load_bert_onnx_model() - model_desc2 = bert_model_description() - trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts) - trainer2.load_checkpoint(os.path.join(ckpt_dir, "bert_toy_save_test.ortcp")) - loaded_sd = trainer2.state_dict() - - # Assert whether original state and the one loaded from checkpoint matches - _test_commons.assert_all_states_close_ort(sd, loaded_sd) - - -def testToyBertCheckpointFrozenWeights(): - # Common setup - seed = 1 - total_steps = 10 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "utils": {"frozen_weights": ["bert.encoder.layer.0.attention.self.value.weight"]}, - } - ) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - optim_config = optim.LambConfig() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train for a few steps - for _i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, seed) - _ = trainer.train_step(*sample_input) - sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1) - # Evaluate once to get a base loss - loss = trainer.eval_step(*sample_input) - # Save checkpoint - state_dict = trainer.state_dict() - - # Load previous state into another instance of ORTTrainer - model2 = load_bert_onnx_model() - model_desc2 = bert_model_description() - optim_config2 = optim.LambConfig() - trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts) - trainer2.load_state_dict(state_dict) - # Evaluate once to get a base loss - ckpt_loss = trainer2.eval_step(*sample_input) - - # Must match as both trainers have the same dict state - assert_allclose(loss.cpu(), ckpt_loss.cpu()) - loaded_state_dict = trainer2.state_dict() - _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict) - - -@pytest.mark.parametrize( - "optimizer, mixedprecision_enabled", - [ - (optim.LambConfig(), False), - (optim.AdamConfig(), False), - (optim.LambConfig(), True), - (optim.AdamConfig(), True), - ], -) -def testToyBertLoadOptimState(optimizer, mixedprecision_enabled): - # Common setup - device = "cuda" - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optimizer - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": {"id": device}, - "mixed_precision": { - "enabled": mixedprecision_enabled, - }, - "distributed": {"allreduce_post_accumulation": True}, - } - ) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - dummy_init_state = _test_commons.generate_dummy_optim_state(model, optimizer) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - trainer.load_state_dict(dummy_init_state) - - # Expected values - input_ids = torch.tensor( - [ - [26598], - [21379], - [19922], - [5219], - [5644], - [20559], - [23777], - [25672], - [22969], - [16824], - [16822], - [635], - [27399], - [20647], - [18519], - [15546], - ], - device=device, - ) - segment_ids = torch.tensor( - [[0], [1], [0], [1], [0], [0], [1], [0], [0], [1], [1], [0], [0], [1], [1], [1]], device=device - ) - input_mask = torch.tensor( - [[0], [0], [0], [0], [1], [1], [1], [0], [1], [1], [0], [0], [0], [1], [0], [0]], device=device - ) - masked_lm_labels = torch.tensor( - [ - [25496], - [16184], - [11005], - [16228], - [14884], - [21660], - [8678], - [23083], - [4027], - [8397], - [11921], - [1333], - [26482], - [1666], - [17925], - [27978], - ], - device=device, - ) - next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device) - - # Actual values - _ = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels) - - actual_state_dict = trainer.state_dict() - del actual_state_dict["model"] - _test_commons.assert_all_states_close_ort(actual_state_dict, dummy_init_state) - - -@pytest.mark.parametrize( - "model_params", - [ - (["bert.embeddings.LayerNorm.bias"]), - ( - [ - "bert.embeddings.LayerNorm.bias", - "bert.embeddings.LayerNorm.weight", - "bert.encoder.layer.0.attention.output.LayerNorm.bias", - ] - ), - ], -) -def testORTTrainerFrozenWeights(model_params): - device = "cuda" - total_steps = 10 - seed = 1 - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - # Setup ORTTrainer WITHOUT frozen weights - opts_dict = { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - opts = orttrainer.ORTTrainerOptions(opts_dict) - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - trainer.train_step(*sample_input) - - # All model_params must be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert all([param in session_state for param in model_params]) - - # Setup ORTTrainer WITH frozen weights - opts_dict.update({"utils": {"frozen_weights": model_params}}) - opts = orttrainer.ORTTrainerOptions(opts_dict) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - trainer.train_step(*sample_input) - - # All model_params CANNOT be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert not any([param in session_state for param in model_params]) - - -def testToyBERTSaveAsONNX(): - device = "cuda" - onnx_file_name = "_____temp_toy_bert_onnx_model.onnx" - if os.path.exists(onnx_file_name): - os.remove(onnx_file_name) - assert not os.path.exists(onnx_file_name) - - # Load trainer - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - trainer.save_as_onnx(onnx_file_name) - assert os.path.exists(onnx_file_name) - - with open(onnx_file_name, "rb") as f: - bin_str = f.read() - reload_onnx_model = onnx.load_model_from_string(bin_str) - os.remove(onnx_file_name) - - # Create a new trainer from persisted ONNX model and compare with original ONNX model - trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config, options=opts) - assert trainer_from_onnx._onnx_model is not None - assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model) - for initializer, loaded_initializer in zip( - trainer._onnx_model.graph.initializer, trainer_from_onnx._onnx_model.graph.initializer - ): - assert initializer.name == loaded_initializer.name - assert onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph( - trainer._onnx_model.graph - ) - _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx) - - -############################################################################### -# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ -############################################################################### -@pytest.mark.parametrize( - "optimizer_config", - [ - (optim.AdamConfig), - # (optim.LambConfig), # TODO: re-enable after nondeterminism on backend is fixed - (optim.SGDConfig), - ], -) -def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config): - # Common setup - train_steps = 512 - - device = "cuda" - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - optim_config = optimizer_config(lr=0.01) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - if optimizer_config == optim.AdamConfig: - legacy_optimizer = "AdamOptimizer" - elif optimizer_config == optim.LambConfig: - legacy_optimizer = "LambOptimizer" - elif optimizer_config == optim.SGDConfig: - legacy_optimizer = "SGDOptimizer" - else: - raise RuntimeError("Invalid optimizer_config") - - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - legacy_optimizer, - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - ) - legacy_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True) - - -@pytest.mark.parametrize( - "initial_lr, lr_scheduler, legacy_lr_scheduler", - [ - (1.0, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (0.5, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (1.0, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler), - (1.0, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler), - (1.0, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler), - ], -) -def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, legacy_lr_scheduler): - ############################################################################ - # These tests require hard-coded values for 'total_steps' and 'initial_lr' # - ############################################################################ - - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - warmup = 0.05 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - - # Setup both Experimental and Legacy LR Schedulers before the experimental loop - if ( - legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler - or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler - ): - legacy_lr_scheduler = partial( - legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup - ) - elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler: - legacy_lr_scheduler = partial( - legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup, cycles=cycles - ) - elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler: - legacy_lr_scheduler = partial( - legacy_lr_scheduler, - initial_lr=initial_lr, - total_steps=total_steps, - warmup=warmup, - power=power, - lr_end=lr_end, - ) - else: - raise RuntimeError("Invalid legacy_lr_scheduler") - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optim.AdamConfig(lr=initial_lr) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "lr_scheduler": lr_scheduler, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(i)) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(initial_lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - get_lr_this_step=legacy_lr_scheduler, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize( - "loss_scaler, legacy_loss_scaler", - [ - (None, Legacy_LossScaler("ort_test_input_loss_scaler", True)), - (amp.DynamicLossScaler(), Legacy_LossScaler("ort_test_input_loss_scaler", True)), - (CustomLossScaler(), LegacyCustomLossScaler()), - ], -) -def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig(lr=0.001) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - use_mixed_precision=True, - loss_scaler=legacy_loss_scaler, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize("gradient_accumulation_steps", [(1), (4), (7)]) -def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation_steps): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - loss = trainer.train_step(*sample_input) - experimental_losses.append(loss.cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - gradient_accumulation_steps=gradient_accumulation_steps, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize( - "params, legacy_optim_map", - [ - # Change the hyper parameters for all parameters - ([], legacy_optim_params_a), - # Change the hyperparameters for a subset of hardcoded parameters - ( - [ - { - "params": ["bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.weight"], - "alpha": 0.9, - "beta": 0.999, - "lambda_coef": 0.0, - "epsilon": 1e-6, - "do_bias_correction": False, - } - ], - legacy_optim_params_b, - ), - # Change the hyperparameters for a generated set of paramers - (optimizer_parameters(load_bert_onnx_model()), legacy_optim_params_c), - ], -) -def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL API - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.AdamConfig( - params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False - ) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr) - - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - legacy_optim_map, - learning_rate_description, - device, - _use_deterministic_compute=True, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - legacy_sample_input = [*sample_input, learning_rate] - legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py deleted file mode 100644 index d366f2cb26557..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py +++ /dev/null @@ -1,722 +0,0 @@ -from unittest.mock import Mock, patch - -import numpy as np -import onnx -import pytest -import torch -from _test_commons import _load_pytorch_transformer_model - -from onnxruntime.training import _checkpoint_storage, amp, checkpoint, optim, orttrainer # noqa: F401 - -# Helper functions - - -def _create_trainer(zero_enabled=False): - """Cerates a simple ORTTrainer for ORTTrainer functional tests""" - - device = "cuda" - optim_config = optim.LambConfig(lr=0.1) - opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} - if zero_enabled: - opts["distributed"] = { - "world_rank": 0, - "world_size": 1, - "horizontal_parallel_size": 1, - "data_parallel_size": 1, - "allreduce_post_accumulation": True, - "deepspeed_zero_optimization": {"stage": 1}, - } - model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer( - model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(opts) - ) - - return trainer - - -class _training_session_mock: # noqa: N801 - """Mock object for the ORTTrainer _training_session member""" - - def __init__(self, model_states, optimizer_states, partition_info): - self.model_states = model_states - self.optimizer_states = optimizer_states - self.partition_info = partition_info - - def get_model_state(self, include_mixed_precision_weights=False): - return self.model_states - - def get_optimizer_state(self): - return self.optimizer_states - - def get_partition_info_map(self): - return self.partition_info - - -def _get_load_state_dict_strict_error_arguments(): - """Return a list of tuples that can be used as parameters for test_load_state_dict_errors_when_model_key_missing - - Construct a list of tuples (training_session_state_dict, input_state_dict, error_arguments) - The load_state_dict function will compare the two state dicts (training_session_state_dict, input_state_dict) and - throw a runtime error with the missing/unexpected keys. The error arguments capture these missing/unexpected keys. - """ - - training_session_state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5)}, - }, - } - - # input state dictionaries - precision_key_missing = {"model": {}, "optimizer": {}} - precision_key_unexpected = {"model": {"full_precision": {}, "mixed_precision": {}}, "optimizer": {}} - model_state_key_missing = {"model": {"full_precision": {}}, "optimizer": {}} - model_state_key_unexpected = {"model": {"full_precision": {"a": 2, "b": 3, "c": 4}}, "optimizer": {}} - optimizer_model_state_key_missing = {"model": {"full_precision": {"a": 2, "b": 3}}, "optimizer": {}} - optimizer_model_state_key_unexpected = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": {"a": {}, "shared_optimizer_state": {}, "b": {}}, - } - optimizer_state_key_missing = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": {"a": {}, "shared_optimizer_state": {"step": np.arange(5)}}, - } - optimizer_state_key_unexpected = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5), "another_step": np.arange(1)}, - }, - } - - input_arguments = [ - (training_session_state_dict, precision_key_missing, ["full_precision"]), - (training_session_state_dict, precision_key_unexpected, ["mixed_precision"]), - (training_session_state_dict, model_state_key_missing, ["a", "b"]), - (training_session_state_dict, model_state_key_unexpected, ["c"]), - (training_session_state_dict, optimizer_model_state_key_missing, ["a", "shared_optimizer_state"]), - (training_session_state_dict, optimizer_model_state_key_unexpected, ["b"]), - (training_session_state_dict, optimizer_state_key_missing, ["Moment_1", "Moment_2"]), - (training_session_state_dict, optimizer_state_key_unexpected, ["another_step"]), - ] - - return input_arguments - - -# Tests - - -def test_empty_state_dict_when_training_session_uninitialized(): - trainer = _create_trainer() - with pytest.warns(UserWarning) as user_warning: - state_dict = trainer.state_dict() - - assert len(state_dict.keys()) == 0 - assert ( - user_warning[0].message.args[0] == "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict()." - ) - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_model_states(onnx_model_mock): - trainer = _create_trainer() - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["model"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_model_states(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_model_states_pytorch_format(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict(pytorch_format=True) - assert torch.all(torch.eq(state_dict["a"], torch.tensor(np.arange(5)))) - assert torch.all(torch.eq(state_dict["b"], torch.tensor(np.arange(7)))) - - -@patch("onnx.ModelProto") -def test_onnx_graph_provides_frozen_model_states(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - trainer.options.utils.frozen_weights = ["a_frozen_weight", "a_float16_weight"] - trainer._onnx_model.graph.initializer = [ - onnx.numpy_helper.from_array(np.array([1, 2, 3], dtype=np.float32), "a_frozen_weight"), - onnx.numpy_helper.from_array(np.array([4, 5, 6], dtype=np.float32), "a_non_fronzen_weight"), - onnx.numpy_helper.from_array(np.array([7, 8, 9], dtype=np.float16), "a_float16_weight"), - ] - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - assert (state_dict["model"]["full_precision"]["a_frozen_weight"] == np.array([1, 2, 3], dtype=np.float32)).all() - assert "a_non_fronzen_weight" not in state_dict["model"]["full_precision"] - assert (state_dict["model"]["full_precision"]["a_float16_weight"] == np.array([7, 8, 9], dtype=np.float32)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_optimizer_states(onnx_model_mock): - trainer = _create_trainer() - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["optimizer"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_optimizer_states(onnx_model_mock): - trainer = _create_trainer() - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - training_session_mock = _training_session_mock({}, optimizer_states, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all() - assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_optimizer_states_pytorch_format(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - training_session_mock = _training_session_mock(model_states, optimizer_states, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict(pytorch_format=True) - assert "optimizer" not in state_dict - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_partition_info_map(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["partition_info"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_partition_info_map(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - partition_info = {"a": {"original_dim": [1, 2, 3]}} - training_session_mock = _training_session_mock({}, {}, partition_info) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3] - - -@patch("onnx.ModelProto") -def test_training_session_provides_all_states(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - partition_info = {"a": {"original_dim": [1, 2, 3]}} - training_session_mock = _training_session_mock(model_states, optimizer_states, partition_info) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all() - assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all() - assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3] - - -def test_load_state_dict_holds_when_training_session_not_initialized(): - trainer = _create_trainer() - state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5)}, - }, - } - assert not trainer._load_state_dict - state_dict = trainer.load_state_dict(state_dict) - assert trainer._load_state_dict - - -@pytest.mark.parametrize( - "state_dict, input_state_dict, error_key", - [ - ( - {"model": {}, "optimizer": {}}, - {"model": {}, "optimizer": {}, "trainer_options": {"optimizer_name": "LambOptimizer"}}, - "train_step_info", - ), - ( - {"optimizer": {}, "train_step_info": {"optimization_step": 0, "step": 0}}, - { - "optimizer": {}, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - "train_step_info": {"optimization_step": 0, "step": 0}, - }, - "model", - ), - ( - {"model": {}, "train_step_info": {"optimization_step": 0, "step": 0}}, - { - "model": {}, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - "train_step_info": {"optimization_step": 0, "step": 0}, - }, - "optimizer", - ), - ], -) -def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, input_state_dict, error_key): - trainer = _create_trainer() - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=state_dict) - trainer._update_onnx_model_initializers = Mock() - trainer._init_session = Mock() - with patch("onnx.ModelProto") as onnx_model_mock: - trainer._onnx_model = onnx_model_mock() - trainer._onnx_model.graph.initializer = [] - with pytest.warns(UserWarning) as user_warning: - trainer.load_state_dict(input_state_dict) - - assert user_warning[0].message.args[0] == f"Missing key: {error_key} in state_dict" - - -@pytest.mark.parametrize("state_dict, input_state_dict, error_keys", _get_load_state_dict_strict_error_arguments()) -def test_load_state_dict_errors_when_state_dict_mismatch(state_dict, input_state_dict, error_keys): - trainer = _create_trainer() - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=state_dict) - with pytest.raises(RuntimeError) as runtime_error: - trainer.load_state_dict(input_state_dict) - - assert any(key in str(runtime_error.value) for key in error_keys) - - -@patch("onnx.ModelProto") -def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_mock): - trainer = _create_trainer() - training_session_state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - }, - } - - input_state_dict = { - "model": {"full_precision": {"a": np.array([1, 2]), "b": np.array([3, 4])}}, - "optimizer": { - "a": {"Moment_1": np.array([5, 6]), "Moment_2": np.array([7, 8])}, - "shared_optimizer_state": {"step": np.array([9])}, - }, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - } - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=training_session_state_dict) - trainer._onnx_model = onnx_model_mock() - trainer._onnx_model.graph.initializer = [ - onnx.numpy_helper.from_array(np.arange(20, dtype=np.float32), "a"), - onnx.numpy_helper.from_array(np.arange(25, dtype=np.float32), "b"), - ] - trainer._update_onnx_model_initializers = Mock() - trainer._init_session = Mock() - - trainer.load_state_dict(input_state_dict) - - loaded_initializers, _ = trainer._update_onnx_model_initializers.call_args - state_dict_to_load, _ = trainer._init_session.call_args - - assert "a" in loaded_initializers[0] - assert (loaded_initializers[0]["a"] == np.array([1, 2])).all() - assert "b" in loaded_initializers[0] - assert (loaded_initializers[0]["b"] == np.array([3, 4])).all() - - assert (state_dict_to_load[0]["a"]["Moment_1"] == np.array([5, 6])).all() - assert (state_dict_to_load[0]["a"]["Moment_2"] == np.array([7, 8])).all() - assert (state_dict_to_load[0]["shared_optimizer_state"]["step"] == np.array([9])).all() - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_calls_checkpoint_storage_save(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc") - - save_args, _ = save_mock.call_args - assert "model" in save_args[0] - assert not bool(save_args[0]["model"]) - assert "optimizer" in save_args[0] - assert not bool(save_args[0]["optimizer"]) - assert save_args[1] == "abc" - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_exclude_optimizer_states(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc", include_optimizer_states=False) - - save_args, _ = save_mock.call_args - assert "model" in save_args[0] - assert not bool(save_args[0]["model"]) - assert "optimizer" not in save_args[0] - assert save_args[1] == "abc" - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_user_dict(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc", user_dict={"abc": np.arange(4)}) - - save_args, _ = save_mock.call_args - assert "user_dict" in save_args[0] - assert save_args[0]["user_dict"] == _checkpoint_storage.to_serialized_hex({"abc": np.arange(4)}) - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -def test_load_checkpoint(aggregate_checkpoints_mock, load_mock): - trainer = _create_trainer() - trainer_options = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - } - state_dict = { - "model": {}, - "optimizer": {}, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - }, - } - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options, state_dict] - trainer.load_checkpoint("abc") - - args_list = load_mock.call_args_list - load_args, load_kwargs = args_list[0] - assert load_args[0] == "abc" - assert load_kwargs["key"] == "trainer_options" - load_args, load_kwargs = args_list[1] - assert load_args[0] == "abc" - assert "key" not in load_kwargs - assert not aggregate_checkpoints_mock.called - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -@pytest.mark.parametrize( - "trainer_options", - [ - { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(4), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(4), - "zero_stage": np.int64(1), - }, - { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(1), - }, - { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(1), - }, - ], -) -def test_load_checkpoint_aggregation_required_zero_enabled(aggregate_checkpoints_mock, load_mock, trainer_options): - trainer = _create_trainer() - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options] - trainer.load_checkpoint("abc") - - args_list = load_mock.call_args_list - load_args, load_kwargs = args_list[0] - assert load_args[0] == "abc" - assert load_kwargs["key"] == "trainer_options" - assert aggregate_checkpoints_mock.called - call_args, _ = aggregate_checkpoints_mock.call_args - assert call_args[0] == tuple(["abc"]) - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock): - trainer = _create_trainer() - trainer_options = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - } - state_dict = { - "model": {}, - "optimizer": {}, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - }, - "user_dict": _checkpoint_storage.to_serialized_hex({"array": torch.tensor(np.arange(5))}), - } - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options, state_dict] - user_dict = trainer.load_checkpoint("abc") - - assert torch.all(torch.eq(user_dict["array"], torch.tensor(np.arange(5)))) - - -@patch("onnxruntime.training._checkpoint_storage.load") -def test_checkpoint_aggregation(load_mock): - trainer_options1 = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - trainer_options2 = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(1), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - - state_dict1 = { - "model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "optimizer_sharded": { - "Moment_1": np.array([9, 8, 7]), - "Moment_2": np.array([99, 88, 77]), - "Step": np.array([5]), - }, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}}, - } - - state_dict2 = { - "model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "optimizer_sharded": { - "Moment_1": np.array([6, 5, 4]), - "Moment_2": np.array([66, 55, 44]), - "Step": np.array([5]), - }, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(1), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}}, - } - - load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] - state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False) - - assert (state_dict["model"]["full_precision"]["optimizer_sharded"] == np.array([1, 2, 3])).all() - assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Step"] == np.array([5])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all() - - assert state_dict["trainer_options"]["mixed_precision"] is False - assert state_dict["trainer_options"]["world_rank"] == 0 - assert state_dict["trainer_options"]["world_size"] == 1 - assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1 - assert state_dict["trainer_options"]["data_parallel_size"] == 1 - assert state_dict["trainer_options"]["zero_stage"] == 0 - assert state_dict["trainer_options"]["optimizer_name"] == b"Adam" - - -@patch("onnxruntime.training._checkpoint_storage.load") -def test_checkpoint_aggregation_mixed_precision(load_mock): - trainer_options1 = { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - trainer_options2 = { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(1), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - - state_dict1 = { - "model": {"full_precision": {"sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "sharded": {"Moment_1": np.array([9, 8, 7]), "Moment_2": np.array([99, 88, 77]), "Step": np.array([5])}, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"sharded": {"original_dim": np.array([2, 3])}}, - } - - state_dict2 = { - "model": {"full_precision": {"sharded": np.array([4, 5, 6]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "sharded": {"Moment_1": np.array([6, 5, 4]), "Moment_2": np.array([66, 55, 44]), "Step": np.array([5])}, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(1), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"sharded": {"original_dim": np.array([2, 3])}}, - } - - load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] - state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False) - - assert (state_dict["model"]["full_precision"]["sharded"] == np.array([[1, 2, 3], [4, 5, 6]])).all() - assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all() - assert (state_dict["optimizer"]["sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all() - assert (state_dict["optimizer"]["sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all() - assert (state_dict["optimizer"]["sharded"]["Step"] == np.array([5])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all() - - assert state_dict["trainer_options"]["mixed_precision"] is True - assert state_dict["trainer_options"]["world_rank"] == 0 - assert state_dict["trainer_options"]["world_size"] == 1 - assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1 - assert state_dict["trainer_options"]["data_parallel_size"] == 1 - assert state_dict["trainer_options"]["zero_stage"] == 0 - assert state_dict["trainer_options"]["optimizer_name"] == b"Adam" diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py deleted file mode 100644 index fa13625f0ddac..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ /dev/null @@ -1,2460 +0,0 @@ -import inspect -import os -import tempfile -from functools import partial - -import _test_commons -import _test_helpers -import onnx -import pytest -import torch -import torch.nn.functional as F -from numpy.testing import assert_allclose -from packaging.version import Version as StrictVersion - -from onnxruntime import SessionOptions, set_seed -from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import PropagateCastOpsStrategy, TrainStepInfo, _utils, amp -from onnxruntime.training import model_desc_validation as md_val -from onnxruntime.training import optim, orttrainer, orttrainer_options - -############################################################################### -# Testing starts here ######################################################### -############################################################################### - -pytorch_110 = StrictVersion(".".join(torch.__version__.split(".")[:2])) >= StrictVersion("1.10.0") - - -def get_model_opset(model_onnx): - for op in model_onnx.opset_import: - if op.domain == "": - return op.version - return None - - -@pytest.mark.parametrize( - "test_input", - [({}), ({"batch": {}, "device": {}, "distributed": {}, "mixed_precision": {}, "utils": {}, "_internal_use": {}})], -) -def testORTTrainerOptionsDefaultValues(test_input): - """Test different ways of using default values for incomplete input""" - - expected_values = { - "batch": {"gradient_accumulation_steps": 1}, - "device": {"id": "cuda", "mem_limit": 0}, - "distributed": { - "world_rank": 0, - "world_size": 1, - "local_rank": 0, - "data_parallel_size": 1, - "horizontal_parallel_size": 1, - "pipeline_parallel": { - "pipeline_parallel_size": 1, - "num_pipeline_micro_batches": 1, - "pipeline_cut_info_string": "", - "sliced_schema": {}, - "sliced_axes": {}, - "sliced_tensor_names": [], - }, - "allreduce_post_accumulation": False, - "deepspeed_zero_optimization": { - "stage": 0, - }, - "enable_adasum": False, - }, - "lr_scheduler": None, - "mixed_precision": {"enabled": False, "loss_scaler": None}, - "graph_transformer": { - "attn_dropout_recompute": False, - "gelu_recompute": False, - "transformer_layer_recompute": False, - "number_recompute_layers": 0, - "propagate_cast_ops_config": {"strategy": PropagateCastOpsStrategy.FLOOD_FILL, "level": 1, "allow": []}, - }, - "utils": { - "frozen_weights": [], - "grad_norm_clip": True, - "memory_efficient_gradient": False, - "run_symbolic_shape_infer": False, - }, - "debug": { - "deterministic_compute": False, - "check_model_export": False, - "graph_save_paths": { - "model_after_graph_transforms_path": "", - "model_with_gradient_graph_path": "", - "model_with_training_graph_path": "", - "model_with_training_graph_after_optimization_path": "", - }, - }, - "_internal_use": { - "enable_internal_postprocess": True, - "extra_postprocess": None, - "onnx_opset_version": 14, - "enable_onnx_contrib_ops": True, - }, - "provider_options": {}, - "session_options": None, - } - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values._validated_opts == expected_values - - -@pytest.mark.parametrize( - "input,error_msg", - [ - ( - {"mixed_precision": {"enabled": 1}}, - "Invalid options: {'mixed_precision': [{'enabled': ['must be of boolean type']}]}", - ) - ], -) -def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema(input, error_msg): - """Test an invalid input based on schema validation error message""" - - with pytest.raises(ValueError) as e: - orttrainer_options.ORTTrainerOptions(input) - assert str(e.value) == error_msg - - -@pytest.mark.parametrize( - "input_dict,input_dtype,output_dtype", - [ - ( - {"inputs": [("in0", [])], "outputs": [("out0", []), ("out1", [])]}, - (torch.int,), - ( - torch.float, - torch.int32, - ), - ), - ({"inputs": [("in0", ["batch", 2, 3])], "outputs": [("out0", [], True)]}, (torch.int8,), (torch.int16,)), - ( - { - "inputs": [ - ("in0", []), - ("in1", [1]), - ("in2", [1, 2]), - ("in3", [1000, "dyn_ax1"]), - ("in4", ["dyn_ax1", "dyn_ax2", "dyn_ax3"]), - ], - "outputs": [("out0", [], True), ("out1", [1], False), ("out2", [1, "dyn_ax1", 3])], - }, - ( - torch.float, - torch.uint8, - torch.bool, - torch.double, - torch.half, - ), - (torch.float, torch.float, torch.int64), - ), - ], -) -def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype): - r"""Test different ways of using default values for incomplete input""" - - model_description = md_val._ORTTrainerModelDesc(input_dict) - - # Validating hard-coded learning rate description - assert model_description.learning_rate.name == md_val.LEARNING_RATE_IO_DESCRIPTION_NAME - assert model_description.learning_rate.shape == [1] - assert model_description.learning_rate.dtype == torch.float32 - - # Validating model description from user - for idx, i_desc in enumerate(model_description.inputs): - assert isinstance(i_desc, model_description._InputDescription) - assert len(i_desc) == 2 - assert input_dict["inputs"][idx][0] == i_desc.name - assert input_dict["inputs"][idx][1] == i_desc.shape - for idx, o_desc in enumerate(model_description.outputs): - assert isinstance(o_desc, model_description._OutputDescription) - assert len(o_desc) == 3 - assert input_dict["outputs"][idx][0] == o_desc.name - assert input_dict["outputs"][idx][1] == o_desc.shape - is_loss = input_dict["outputs"][idx][2] if len(input_dict["outputs"][idx]) == 3 else False - assert is_loss == o_desc.is_loss - - # Set all_finite name and check its description - model_description.all_finite = md_val.ALL_FINITE_IO_DESCRIPTION_NAME - assert model_description.all_finite.name == md_val.ALL_FINITE_IO_DESCRIPTION_NAME - assert model_description.all_finite.shape == [1] - assert model_description.all_finite.dtype == torch.bool - - # Set loss_scale_input and check its description - model_description.loss_scale_input = md_val.LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME - assert model_description.loss_scale_input.name == md_val.LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME - assert model_description.loss_scale_input.shape == [] - assert model_description.loss_scale_input.dtype == torch.float32 - - # Append type to inputs/outputs tuples - for idx, i_desc in enumerate(model_description.inputs): # noqa: B007 - model_description.add_type_to_input_description(idx, input_dtype[idx]) - for idx, o_desc in enumerate(model_description.outputs): # noqa: B007 - model_description.add_type_to_output_description(idx, output_dtype[idx]) - - # Verify inputs/outputs tuples are replaced by the typed counterparts - for idx, i_desc in enumerate(model_description.inputs): - assert isinstance(i_desc, model_description._InputDescriptionTyped) - assert input_dtype[idx] == i_desc.dtype - for idx, o_desc in enumerate(model_description.outputs): - assert isinstance(o_desc, model_description._OutputDescriptionTyped) - assert output_dtype[idx] == o_desc.dtype - - -@pytest.mark.parametrize( - "input_dict,error_msg", - [ - ( - {"inputs": [(True, [])], "outputs": [(True, [])]}, - "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], " - "'outputs': [{0: ['the first element of the tuple (aka name) must be a string']}]}", - ), - ( - {"inputs": [("in1", None)], "outputs": [("out1", None)]}, - "Invalid model_desc: {'inputs': [{0: ['the second element of the tuple (aka shape) must be a list']}], " - "'outputs': [{0: ['the second element of the tuple (aka shape) must be a list']}]}", - ), - ( - {"inputs": [("in1", [])], "outputs": [("out1", [], None)]}, - "Invalid model_desc: {'outputs': [{0: ['the third element of the tuple (aka is_loss) must be a boolean']}]}", - ), - ( - {"inputs": [("in1", [True])], "outputs": [("out1", [True])]}, - "Invalid model_desc: {'inputs': [{0: ['each shape must be either a string or integer']}], " - "'outputs': [{0: ['each shape must be either a string or integer']}]}", - ), - ( - {"inputs": [("in1", [])], "outputs": [("out1", [], True), ("out2", [], True)]}, - "Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}", - ), - ( - {"inputz": [("in1", [])], "outputs": [("out1", [], True)]}, - "Invalid model_desc: {'inputs': ['required field'], 'inputz': ['unknown field']}", - ), - ( - {"inputs": [("in1", [])], "outputz": [("out1", [], True)]}, - "Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}", - ), - ], -) -def testORTTrainerModelDescInvalidSchemas(input_dict, error_msg): - r"""Test different ways of using default values for incomplete input""" - with pytest.raises(ValueError) as e: - md_val._ORTTrainerModelDesc(input_dict) - assert str(e.value) == error_msg - - -def testDynamicLossScaler(): - rtol = 1e-7 - default_scaler = amp.loss_scaler.DynamicLossScaler() - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(default_scaler.loss_scale, float(1 << 16), rtol=rtol, err_msg="loss scale mismatch") - assert default_scaler.up_scale_window == 2000 - assert_allclose(default_scaler.min_loss_scale, 1.0, rtol=rtol, err_msg="min loss scale mismatch") - assert_allclose(default_scaler.max_loss_scale, float(1 << 24), rtol=rtol, err_msg="max loss scale mismatch") - - # Performing 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True) - loss_scale = float(1 << 16) - for cycles in range(1, 10): - # 1999 updates without overflow produces 1999 stable steps - for i in range(1, 2000): - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == i - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg=f"loss scale mismatch at update {i}") - - # 2000th update without overflow doubles the loss and zero stable steps until max_loss_scale is reached - new_loss_scale = default_scaler.update(train_step_info) - if cycles <= 8: - loss_scale *= 2 - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # After 8 cycles, loss scale should be float(1 << 16)*(2**8) - assert_allclose(new_loss_scale, float(1 << 16) * (2**8), rtol=rtol, err_msg="loss scale mismatch") - - # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on - loss_scale = float(1 << 16) * (2**8) - for count in range(1, 2050): - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == (count % 2000) - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # Setting train_step_info.all_finite = False to test down scaling - train_step_info.all_finite = False - - # Performing 24 updates to half the loss scale each time - loss_scale = float(1 << 16) * (2**8) - for count in range(1, 25): # noqa: B007 - new_loss_scale = default_scaler.update(train_step_info) - loss_scale /= 2 - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # After 24 updates with gradient overflow, loss scale is 1.0 - assert_allclose(new_loss_scale, 1.0, rtol=rtol, err_msg="loss scale mismatch") - - # After 25 updates, min_loss_scale is reached and loss scale is not halfed from that point on - for count in range(1, 5): # noqa: B007 - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - -def testDynamicLossScalerCustomValues(): - rtol = 1e-7 - scaler = amp.loss_scaler.DynamicLossScaler( - automatic_update=False, loss_scale=3, up_scale_window=7, min_loss_scale=5, max_loss_scale=10 - ) - assert scaler.automatic_update is False - assert_allclose(scaler.loss_scale, 3, rtol=rtol, err_msg="loss scale mismatch") - assert_allclose(scaler.min_loss_scale, 5, rtol=rtol, err_msg="min loss scale mismatch") - assert_allclose(scaler.max_loss_scale, 10, rtol=rtol, err_msg="max loss scale mismatch") - assert scaler.up_scale_window == 7 - - -def testTrainStepInfo(): - """Test valid initializations of TrainStepInfo""" - - optimizer_config = optim.LambConfig() - fetches = ["out1", "out2"] - step_info = orttrainer.TrainStepInfo( - optimizer_config=optimizer_config, all_finite=False, fetches=fetches, optimization_step=123, step=456 - ) - assert step_info.optimizer_config == optimizer_config - assert step_info.all_finite is False - assert step_info.fetches == fetches - assert step_info.optimization_step == 123 - assert step_info.step == 456 - - step_info = orttrainer.TrainStepInfo(optimizer_config) - assert step_info.optimizer_config == optimizer_config - assert step_info.all_finite is True - assert step_info.fetches == [] - assert step_info.optimization_step == 0 - assert step_info.step == 0 - - -@pytest.mark.parametrize( - "invalid_input", - [ - (-1), - ("Hello"), - ], -) -def testTrainStepInfoInvalidInput(invalid_input): - """Test invalid initialization of TrainStepInfo""" - optimizer_config = optim.LambConfig() - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, all_finite=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, fetches=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, optimization_step=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, step=invalid_input) - - -@pytest.mark.parametrize( - "optim_name,lr,alpha,default_alpha", - [ - ("AdamOptimizer", 0.1, 0.2, None), - ("LambOptimizer", 0.2, 0.3, None), - ("SGDOptimizer", 0.3, 0.4, None), - ("SGDOptimizer", 0.3, 0.4, 0.5), - ], -) -def testOptimizerConfig(optim_name, lr, alpha, default_alpha): - """Test initialization of _OptimizerConfig""" - defaults = {"lr": lr, "alpha": alpha} - params = [{"params": ["fc1.weight", "fc2.weight"]}] - if default_alpha is not None: - params[0].update({"alpha": default_alpha}) - else: - params[0].update({"alpha": alpha}) - cfg = optim.config._OptimizerConfig(name=optim_name, params=params, defaults=defaults) - - assert cfg.name == optim_name - rtol = 1e-07 - assert_allclose(defaults["lr"], cfg.lr, rtol=rtol, err_msg="lr mismatch") - - # 1:1 mapping between defaults and params's hyper parameters - for param in params: - for k in param: - if k != "params": - assert k in cfg.defaults, "hyper parameter {k} not present in one of the parameter params" - for k in cfg.defaults: - for param in cfg.params: - assert k in param, "hyper parameter {k} not present in one of the parameter params" - - -@pytest.mark.parametrize( - "optim_name,defaults,params", - [ - ("AdamOptimizer", {"lr": -1}, []), # invalid lr - ("FooOptimizer", {"lr": 0.001}, []), # invalid name - ("SGDOptimizer", [], []), # invalid type(defaults) - (optim.AdamConfig, {"lr": 0.003}, []), # invalid type(name) - ("AdamOptimizer", {"lr": None}, []), # missing 'lr' hyper parameter - ("SGDOptimizer", {"lr": 0.004}, {}), # invalid type(params) - # invalid type(params[i]) - ("AdamOptimizer", {"lr": 0.005, "alpha": 2}, [[]]), - # missing 'params' at 'params' - ("AdamOptimizer", {"lr": 0.005, "alpha": 2}, [{"alpha": 1}]), - # missing 'alpha' at 'defaults' - ("AdamOptimizer", {"lr": 0.005}, [{"params": "param1", "alpha": 1}]), - ], -) -def testOptimizerConfigInvalidInputs(optim_name, defaults, params): - """Test invalid initialization of _OptimizerConfig""" - - with pytest.raises(AssertionError): - optim.config._OptimizerConfig(name=optim_name, params=params, defaults=defaults) - - -def testOptimizerConfigSGD(): - """Test initialization of SGD""" - cfg = optim.SGDConfig() - assert cfg.name == "SGDOptimizer" - - rtol = 1e-07 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - - cfg = optim.SGDConfig(lr=0.002) - assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch") - - # SGD does not support params - with pytest.raises(AssertionError) as e: - params = [{"params": ["layer1.weight"], "lr": 0.1}] - optim.SGDConfig(params=params, lr=0.002) - assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert str(e.value) == "'params' must be an empty list for SGD optimizer" - - -def testOptimizerConfigAdam(): - """Test initialization of Adam""" - cfg = optim.AdamConfig() - assert cfg.name == "AdamOptimizer" - - rtol = 1e-7 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch") - assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch") - assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch") - assert_allclose(1e-8, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch") - assert_allclose(1.0, cfg.max_norm_clip, rtol=rtol, err_msg="max_norm_clip mismatch") - assert cfg.do_bias_correction is True, "lambda_coef mismatch" - assert cfg.weight_decay_mode == optim.AdamConfig.DecayMode.BEFORE_WEIGHT_UPDATE, "weight_decay_mode mismatch" - - -def testOptimizerConfigLamb(): - """Test initialization of Lamb""" - cfg = optim.LambConfig() - assert cfg.name == "LambOptimizer" - rtol = 1e-7 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch") - assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch") - assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch") - assert cfg.ratio_min == float("-inf"), "ratio_min mismatch" - assert cfg.ratio_max == float("inf"), "ratio_max mismatch" - assert_allclose(1e-6, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch") - assert_allclose(1.0, cfg.max_norm_clip, rtol=rtol, err_msg="max_norm_clip mismatch") - assert cfg.do_bias_correction is False, "do_bias_correction mismatch" - - -@pytest.mark.parametrize("optim_name", [("Adam"), ("Lamb")]) -def testOptimizerConfigParams(optim_name): - rtol = 1e-7 - params = [{"params": ["layer1.weight"], "alpha": 0.1}] - if optim_name == "Adam": - cfg = optim.AdamConfig(params=params, alpha=0.2) - elif optim_name == "Lamb": - cfg = optim.LambConfig(params=params, alpha=0.2) - else: - raise ValueError("invalid input") - assert len(cfg.params) == 1, "params should have length 1" - assert_allclose(cfg.params[0]["alpha"], 0.1, rtol=rtol, err_msg="invalid lr on params[0]") - - -@pytest.mark.parametrize("optim_name", [("Adam"), ("Lamb")]) -def testOptimizerConfigInvalidParams(optim_name): - # lr is not supported within params - with pytest.raises(AssertionError) as e: - params = [{"params": ["layer1.weight"], "lr": 0.1}] - if optim_name == "Adam": - optim.AdamConfig(params=params, lr=0.2) - elif optim_name == "Lamb": - optim.LambConfig(params=params, lr=0.2) - else: - raise ValueError("invalid input") - assert str(e.value) == "'lr' is not supported inside params" - - -def testLinearLRSchedulerCreation(): - total_steps = 10 - warmup = 0.05 - - lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(total_steps, warmup) - - # Initial state - assert lr_scheduler.total_steps == total_steps - assert lr_scheduler.warmup == warmup - - -@pytest.mark.parametrize( - "lr_scheduler,expected_values", - [ - (optim.lr_scheduler.ConstantWarmupLRScheduler, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0]), - ( - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.9763960957919413, - 0.9059835861602854, - 0.7956724530494887, - 0.6563036824392345, - 0.5015739416158049, - 0.34668951940611276, - 0.2068719061737831, - 0.09586187986225325, - 0.0245691111902418, - ], - ), - (optim.lr_scheduler.LinearWarmupLRScheduler, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2]), - ( - optim.lr_scheduler.PolyWarmupLRScheduler, - [ - 0.0, - 0.9509018036072144, - 0.9008016032064128, - 0.8507014028056112, - 0.8006012024048097, - 0.750501002004008, - 0.7004008016032064, - 0.6503006012024048, - 0.6002004008016032, - 0.5501002004008015, - ], - ), - ], -) -def testLRSchedulerUpdateImpl(lr_scheduler, expected_values): - # Test tolerance - rtol = 1e-03 - - # Initial state - initial_lr = 1 - total_steps = 10 - warmup = 0.5 - optimizer_config = optim.SGDConfig(lr=initial_lr) - lr_scheduler = lr_scheduler(total_steps, warmup) - - # First half is warmup - for optimization_step in range(total_steps): - # Emulate ORTTRainer.train_step() call that updates its train_step_info - train_step_info = TrainStepInfo(optimizer_config=optimizer_config, optimization_step=optimization_step) - - lr_scheduler._step(train_step_info) - lr_list = lr_scheduler.get_last_lr() - assert len(lr_list) == 1 - assert_allclose(lr_list[0], expected_values[optimization_step], rtol=rtol, err_msg="lr mismatch") - - -def testInstantiateORTTrainerOptions(): - session_options = SessionOptions() - session_options.enable_mem_pattern = False - provider_options = {"EP1": {"key": "val"}} - opts = {"session_options": session_options, "provider_options": provider_options} - opts = orttrainer.ORTTrainerOptions(opts) - assert opts.session_options.enable_mem_pattern is False - assert opts._validated_opts["provider_options"]["EP1"]["key"] == "val" - - -@pytest.mark.parametrize( - "step_fn, lr_scheduler, expected_lr_values, device", - [ - ("train_step", None, None, "cuda"), - ("eval_step", None, None, "cpu"), - ( - "train_step", - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0], - "cpu", - ), - ( - "train_step", - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.2, - 0.4, - 0.6, - 0.8, - 1.0, - 0.9045084971874737, - 0.6545084971874737, - 0.34549150281252633, - 0.09549150281252633, - ], - "cuda", - ), - ( - "train_step", - optim.lr_scheduler.LinearWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2], - "cpu", - ), - ( - "train_step", - optim.lr_scheduler.PolyWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.80000002, 0.60000004, 0.40000006000000005, 0.20000007999999997], - "cuda", - ), - ], -) -def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device): - total_steps = 1 - initial_lr = 1.0 - rtol = 1e-3 - - # PyTorch Transformer model as example - opts = {"device": {"id": device}} - if lr_scheduler: - total_steps = 10 - opts.update({"lr_scheduler": lr_scheduler(total_steps=total_steps, warmup=0.5)}) - opts = orttrainer.ORTTrainerOptions(opts) - optim_config = optim.LambConfig(lr=initial_lr) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - - # Run a train or evaluation step - if step_fn == "eval_step": - data, targets = batcher_fn(val_data, 0) - elif step_fn == "train_step": - data, targets = batcher_fn(train_data, 0) - else: - raise ValueError("Invalid step_fn") - - # Export model to ONNX - if step_fn == "eval_step": - step_fn = trainer.eval_step - output = trainer.eval_step(data, targets) - elif step_fn == "train_step": - step_fn = trainer.train_step - for i in range(total_steps): - output = trainer.train_step(data, targets) - if lr_scheduler: - lr_list = trainer.options.lr_scheduler.get_last_lr() - assert_allclose(lr_list[0], expected_lr_values[i], rtol=rtol, err_msg="lr mismatch") - else: - raise ValueError("Invalid step_fn") - assert trainer._onnx_model is not None - - # Check output shape after train/eval step - for out, desc in zip(output, trainer.model_desc.outputs): - if trainer.loss_fn and desc.is_loss: - continue - assert list(out.size()) == desc.shape - - # Check name, shape and dtype of the first len(forward.parameters) ORT graph inputs - sig = inspect.signature(model.forward) - for i in range(len(sig.parameters.keys())): - input_name = trainer.model_desc.inputs[i][0] - input_dim = trainer.model_desc.inputs[i][1] - input_type = trainer.model_desc.inputs[i][2] - - assert trainer._onnx_model.graph.input[i].name == input_name - for dim_idx, dim in enumerate(trainer._onnx_model.graph.input[i].type.tensor_type.shape.dim): - assert input_dim[dim_idx] == dim.dim_value - assert input_type == _utils.dtype_onnx_to_torch( - trainer._onnx_model.graph.input[i].type.tensor_type.elem_type - ) - - opset = get_model_opset(trainer._onnx_model) - - # Check name, shape and dtype of the ORT graph outputs - for i in range(len(trainer.model_desc.outputs)): - output_name = trainer.model_desc.outputs[i][0] - output_dim = trainer.model_desc.outputs[i][1] - output_type = trainer.model_desc.outputs[i][3] - - assert trainer._onnx_model.graph.output[i].name == output_name - for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim): - if opset is None or opset <= 12: - assert output_dim[dim_idx] == dim.dim_value - assert output_type == _utils.dtype_onnx_to_torch( - trainer._onnx_model.graph.output[i].type.tensor_type.elem_type - ) - - # Save current model as ONNX as a file - file_name = os.path.join("_____temp_onnx_model.onnx") - trainer.save_as_onnx(file_name) - assert os.path.exists(file_name) - with open(file_name, "rb") as f: - bin_str = f.read() - reload_onnx_model = onnx.load_model_from_string(bin_str) - os.remove(file_name) - - # Create a new trainer from persisted ONNX model and compare with original ONNX model - trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config) - step_fn(data, targets) - assert trainer_from_onnx._onnx_model is not None - assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model) - assert trainer_from_onnx._onnx_model == trainer._onnx_model - assert trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph - assert onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph( - trainer._onnx_model.graph - ) - - -@pytest.mark.parametrize("seed, device", [(0, "cpu"), (24, "cuda")]) -def testORTDeterministicCompute(seed, device): - # Common setup - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - {"debug": {"deterministic_compute": True}, "device": {"id": device, "mem_limit": 10 * 1024 * 1024}} - ) - - # Setup for the first ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - _ = first_trainer.train_step(data, targets) - assert first_trainer._onnx_model is not None - - # Setup for the second ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - model, _, _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device) - second_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - _ = second_trainer.train_step(data, targets) - assert second_trainer._onnx_model is not None - - # Compare two different instances with identical setup - assert id(first_trainer._onnx_model) != id(second_trainer._onnx_model) - _test_helpers.assert_onnx_weights(first_trainer, second_trainer) - - -@pytest.mark.parametrize( - "seed,device,expected_loss,fetches", - [ - (321, "cuda", [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], False), - (321, "cuda", [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], True), - ], -) -def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches): - return # TODO: re-enable after nondeterminism on backend is fixed. update numbers - - rtol = 1e-3 - total_steps = len(expected_loss) - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - loss_scaler = amp.DynamicLossScaler() - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - if fetches: - trainer._train_step_info.fetches = ["loss"] - loss = trainer.train_step(data, targets) - else: - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Eval once just to test fetches in action - val_data, val_targets = batcher_fn(val_data, 0) - if fetches: - trainer._train_step_info.fetches = ["loss"] - loss = trainer.eval_step(val_data, val_targets) - trainer._train_step_info.fetches = [] - loss, _ = trainer.eval_step(val_data, val_targets) - - # Compare loss to ground truth computed from current ORTTrainer API - _test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=rtol) - assert trainer._onnx_model is not None - - -def _recompute_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - expected_loss = { - 12: [10.5598, 10.4591, 10.3477, 10.2726, 10.1945], - 14: [10.54088, 10.498755, 10.386827, 10.338747, 10.262459], - } - return [ - (False, False, False, 0, expected_loss), # no recompute - (True, False, False, 0, expected_loss), # attn_dropout recompute - (False, True, False, 0, expected_loss), # gelu recompute - (False, False, True, 0, expected_loss), # transformer_layer recompute - (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer - ] - elif device_capability_major == 5: # M60 for CI machines - expected_loss = { - 12: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], - 14: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], - } - return [ - (False, False, False, 0, expected_loss), # no recompute - (True, False, False, 0, expected_loss), # attn_dropout recompute - (False, True, False, 0, expected_loss), # gelu recompute - (False, False, True, 0, expected_loss), # transformer_layer recompute - (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer - ] - - -@pytest.mark.parametrize("attn_dropout, gelu, transformer_layer, number_layers, expected_loss", _recompute_data()) -def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers, expected_loss): - seed = 321 - device = "cuda" - rtol = 1e-3 - total_steps = len(expected_loss[12]) - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "graph_transformer": { - "attn_dropout_recompute": attn_dropout, - "gelu_recompute": gelu, - "transformer_layer_recompute": transformer_layer, - "number_recompute_layers": number_layers, - }, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Compare loss to ground truth computed from current ORTTrainer API - assert trainer._onnx_model is not None - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, True, rtol=rtol) - - -@pytest.mark.parametrize( - "seed,device,gradient_accumulation_steps,total_steps,expected_loss", - [ - ( - 0, - "cuda", - 1, - 12, - [ - 10.5368022919, - 10.4146203995, - 10.3635568619, - 10.2650547028, - 10.2284049988, - 10.1304626465, - 10.0853414536, - 9.9987659454, - 9.9472427368, - 9.8832416534, - 9.8223171234, - 9.8222122192, - ], - ), - ( - 42, - "cuda", - 3, - 12, - [ - 10.6455879211, - 10.6247081757, - 10.6361322403, - 10.5187482834, - 10.5345087051, - 10.5487670898, - 10.4833698273, - 10.4600019455, - 10.4535751343, - 10.3774127960, - 10.4144191742, - 10.3757553101, - ], - ), - ( - 123, - "cuda", - 7, - 12, - [ - 10.5353469849, - 10.5261383057, - 10.5240392685, - 10.5013713837, - 10.5678377151, - 10.5452117920, - 10.5184345245, - 10.4271221161, - 10.4458627701, - 10.4864749908, - 10.4416503906, - 10.4467563629, - ], - ), - ( - 321, - "cuda", - 12, - 12, - [ - 10.5773944855, - 10.5428829193, - 10.5974750519, - 10.5416746140, - 10.6009902954, - 10.5684127808, - 10.5759754181, - 10.5636739731, - 10.5613927841, - 10.5825119019, - 10.6031589508, - 10.6199369431, - ], - ), - ], -) -def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps, expected_loss): - return # TODO: re-enable after nondeterminism on backend is fixed. update numbers - rtol = 1e-3 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(expected_loss, actual_loss, rtol=rtol) - - -@pytest.mark.parametrize( - "dynamic_axes", - [ - (True), - (False), - ], -) -def testORTTrainerDynamicShape(dynamic_axes): - # Common setup - device = "cuda" - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model( - device, dynamic_axes=dynamic_axes - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - total_steps = 10 - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - if dynamic_axes: - # Forcing batches with different sizes to exercise dynamic shapes - data = data[: -(i + 1)] - targets = targets[: -(i + 1) * data.size(1)] - _, _ = trainer.train_step(data, targets) - - assert trainer._onnx_model is not None - - -@pytest.mark.parametrize( - "enable_onnx_contrib_ops", - [ - (True), - (False), - ], -) -def testORTTrainerInternalUseContribOps(enable_onnx_contrib_ops): - # Common setup - device = "cuda" - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({"_internal_use": {"enable_onnx_contrib_ops": enable_onnx_contrib_ops}}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - data, targets = batcher_fn(train_data, 0) - if not enable_onnx_contrib_ops and not pytorch_110: - with pytest.raises(Exception): # noqa: B017 - _, _ = trainer.train_step(data, targets) - else: - _, _ = trainer.train_step(data, targets) - - -@pytest.mark.parametrize( - "model_params", - [ - ( - [ - "decoder.weight", - "transformer_encoder.layers.0.linear1.bias", - "transformer_encoder.layers.0.linear2.weight", - "transformer_encoder.layers.1.self_attn.out_proj.weight", - "transformer_encoder.layers.1.self_attn.out_proj.bias", - ] - ), - ], -) -def testORTTrainerFrozenWeights(model_params): - # Common setup - device = "cuda" - total_steps = 10 - - # Setup ORTTrainer WITHOUT frozen weights - options = orttrainer.ORTTrainerOptions({}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = trainer.train_step(data, targets) - - # All model_params must be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert all([param in session_state for param in model_params]) - - # Setup ORTTrainer WITH frozen weights - options = orttrainer.ORTTrainerOptions({"utils": {"frozen_weights": model_params}}) - model, _, _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = trainer.train_step(data, targets) - - # All model_params CANNOT be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert not all([param in session_state for param in model_params]) - - -@pytest.mark.parametrize( - "loss_scaler, optimizer_config, gradient_accumulation_steps", - [ - (None, optim.AdamConfig(), 1), - (None, optim.LambConfig(), 1), - (None, optim.SGDConfig(), 1), - (amp.DynamicLossScaler(), optim.AdamConfig(), 1), - (amp.DynamicLossScaler(), optim.LambConfig(), 5), - # (amp.DynamicLossScaler(), optim.SGDConfig(), 1), # SGD doesnt support fp16 - ], -) -def testORTTrainerStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps): - # Common setup - seed = 1 - - class LinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - - def forward(self, y=None, x=None): - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - model_desc = { - "inputs": [ - ("x", [2, 2]), - ( - "label", - [ - 2, - ], - ), - ], - "outputs": [("loss", [], True), ("output", [2, 4])], - } - - # Dummy data - data1 = torch.randn(2, 2) - label1 = torch.tensor([0, 1], dtype=torch.int64) - data2 = torch.randn(2, 2) - label2 = torch.tensor([0, 1], dtype=torch.int64) - - # Setup training based on test parameters - opts = { - "debug": {"deterministic_compute": True}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - if loss_scaler: - opts["mixed_precision"] = {"enabled": True, "loss_scaler": loss_scaler} - opts = orttrainer.ORTTrainerOptions(opts) - - # Training session 1 - torch.manual_seed(seed) - set_seed(seed) - pt_model = LinearModel() - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - trainer = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) - - # Check state_dict keys before train. Must be empty - state_dict = trainer.state_dict() - assert state_dict == {} - - # Train once and check initial state - trainer.train_step(x=data1, label=label1) - state_dict = trainer.state_dict() - assert all([weight in state_dict["model"]["full_precision"] for weight in ["linear.bias", "linear.weight"]]) - - # Initialize training session 2 from state of Training 1 - torch.manual_seed(seed) - set_seed(seed) - trainer2 = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) - trainer2.load_state_dict(state_dict) - - # Verify state was loaded properly - _test_commons.assert_all_states_close_ort(state_dict, trainer2._load_state_dict.args[0]) - - # Perform a second step in both training session 1 and 2 and verify they match - trainer.train_step(x=data2, label=label2) - state_dict = trainer.state_dict() - trainer2.train_step(x=data2, label=label2) - state_dict2 = trainer2.state_dict() - _test_commons.assert_all_states_close_ort(state_dict, state_dict2) - - -def testORTTrainerNonPickableModel(): - # Common setup - import threading - - seed = 1 - - class UnpickableModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - self._lock = threading.Lock() - - def forward(self, y=None, x=None): - with self._lock: - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - model_desc = { - "inputs": [ - ("x", [2, 2]), - ( - "label", - [ - 2, - ], - ), - ], - "outputs": [("loss", [], True), ("output", [2, 4])], - } - - # Dummy data - data = torch.randn(2, 2) - label = torch.tensor([0, 1], dtype=torch.int64) - - # Setup training based on test parameters - opts = orttrainer.ORTTrainerOptions({"debug": {"deterministic_compute": True}}) - - # Training session - torch.manual_seed(seed) - set_seed(seed) - pt_model = UnpickableModel() - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - optim_config = optim.AdamConfig() - trainer = orttrainer.ORTTrainer(pt_model, model_desc, optim_config, loss_fn=loss_fn, options=opts) - - # Train must succeed despite warning - _, _ = trainer.train_step(data, label) - - -############################################################################### -# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ -############################################################################### - - -@pytest.mark.parametrize("seed,device", [(1234, "cuda")]) -def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device): - # Common data - rtol = 1e-7 - total_steps = 5 - - # Setup for the experimental ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - # Training loop - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _ = trainer.train_step(data, targets) - - # Setup for the legacy ORTTrainer run - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, my_loss, model_desc, "LambOptimizer", None, lr_desc, device, _use_deterministic_compute=True - ) - # Training loop - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - - # Compare legacy vs experimental APIs - _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=rtol) - - -@pytest.mark.parametrize( - "seed,device", - [ - (321, "cuda"), - ], -) -def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device): - # Common data - total_steps = 128 - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - loss_scaler = amp.DynamicLossScaler() - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - experimental_preds_dtype = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, exp_preds = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - experimental_preds_dtype.append(exp_preds.dtype) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - loss_scaler = Legacy_LossScaler("ort_test_input_loss_scalar", True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - use_mixed_precision=True, - loss_scaler=loss_scaler, - ) - # Training loop - legacy_loss = [] - legacy_preds_dtype = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(leg_loss.cpu()) - legacy_preds_dtype.append(leg_preds.dtype) - - # Compare legacy vs experimental APIs - assert experimental_preds_dtype == legacy_preds_dtype - _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer) - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -@pytest.mark.parametrize( - "seed,device,gradient_accumulation_steps,total_steps", - [ - (0, "cuda", 1, 12), - (42, "cuda", 3, 12), - (123, "cuda", 7, 12), - (321, "cuda", 12, 12), - ], -) -def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps): - # Common data - torch.set_printoptions(precision=10) - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, _ = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - gradient_accumulation_steps=gradient_accumulation_steps, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(leg_loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -@pytest.mark.parametrize( - "seed,device,optimizer_config,lr_scheduler, get_lr_this_step", - [ - ( - 0, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 0, - "cuda", - optim.LambConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 0, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 42, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 42, - "cuda", - optim.LambConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 42, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 123, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 123, - "cuda", - optim.LambConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 123, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 321, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ( - 321, - "cuda", - optim.LambConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ( - 321, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ], -) -def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_config, lr_scheduler, get_lr_this_step): - # Common data - total_steps = 10 - lr = 0.001 - warmup = 0.5 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - torch.set_printoptions(precision=10) - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - options = orttrainer.ORTTrainerOptions( - {"device": {"id": device}, "debug": {"deterministic_compute": True}, "lr_scheduler": lr_scheduler} - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optimizer_config(lr=lr) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, exp_preds = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - - if optimizer_config == optim.AdamConfig: - legacy_optimizer_config = "AdamOptimizer" - elif optimizer_config == optim.LambConfig: - legacy_optimizer_config = "LambOptimizer" - elif optimizer_config == optim.SGDConfig: - legacy_optimizer_config = "SGDOptimizer" - else: - raise RuntimeError("Invalid optimizer_config") - - if ( - get_lr_this_step == _test_commons.legacy_constant_lr_scheduler - or get_lr_this_step == _test_commons.legacy_linear_lr_scheduler - ): - get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup) - elif get_lr_this_step == _test_commons.legacy_cosine_lr_scheduler: - get_lr_this_step = partial( - get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, cycles=cycles - ) - elif get_lr_this_step == _test_commons.legacy_poly_lr_scheduler: - get_lr_this_step = partial( - get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end - ) - else: - raise RuntimeError("Invalid get_lr_this_step") - - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - legacy_optimizer_config, - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - get_lr_this_step=get_lr_this_step, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, leg_preds = legacy_trainer.train_step(data, targets) - legacy_loss.append(leg_loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -def testLossScalerLegacyAndExperimentalFullCycle(): - orttrainer.TrainStepInfo( - optimizer_config=optim.LambConfig(lr=0.001), all_finite=True, fetches=[], optimization_step=0, step=0 - ) - new_ls = amp.DynamicLossScaler() - old_ls = Legacy_LossScaler("ort_test_input_loss_scaler", True) - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(new_ls.loss_scale, old_ls.loss_scale_) - assert new_ls.up_scale_window == old_ls.up_scale_window_ - assert_allclose(new_ls.min_loss_scale, old_ls.min_loss_scale_) - assert_allclose(new_ls.max_loss_scale, old_ls.max_loss_scale_) - - # Performing 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True) - for _cycles in range(1, 10): - # 1999 updates without overflow produces 1999 stable steps - for _i in range(1, 2000): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # 2000th update without overflow doubles the loss and zero stable steps until max_loss_scale is reached - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # After 8 cycles, loss scale should be float(1 << 16)*(2**8) - assert_allclose(new_loss_scale, old_loss_scale) - - # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on - for _count in range(1, 2050): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # Setting train_step_info.all_finite = False to test down scaling - train_step_info.all_finite = False - - # Performing 24 updates to half the loss scale each time - for _count in range(1, 25): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # After 24 updates with gradient overflow, loss scale is 1.0 - assert_allclose(new_loss_scale, old_loss_scale) - - # After 25 updates, min_loss_scale is reached and loss scale is not halfed from that point on - for _count in range(1, 5): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - -def testLossScalerLegacyAndExperimentalRandomAllFinite(): - new_ls = amp.DynamicLossScaler() - old_ls = Legacy_LossScaler("ort_test_input_loss_scaler", True) - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(new_ls.loss_scale, old_ls.loss_scale_) - assert new_ls.up_scale_window == old_ls.up_scale_window_ - assert_allclose(new_ls.min_loss_scale, old_ls.min_loss_scale_) - assert_allclose(new_ls.max_loss_scale, old_ls.max_loss_scale_) - - import random - - out = [] - for _ in range(1, 64): - train_step_info.all_finite = bool(random.getrandbits(1)) - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - out.append(new_loss_scale) - assert new_loss_scale > 1e-7 - - -def testORTTrainerRunSymbolicShapeInfer(): - # Common data - seed = 0 - total_steps = 12 - device = "cuda" - torch.set_printoptions(precision=10) - - # Setup without symbolic shape inference - torch.manual_seed(seed) - set_seed(seed) - options = orttrainer.ORTTrainerOptions({"device": {"id": device}, "debug": {"deterministic_compute": True}}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - expected_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - expected_loss.append(loss.cpu()) - - # Setup with symbolic shape inference - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - options.utils.run_symbolic_shape_infer = True - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - new_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - new_loss.append(loss.cpu()) - - # Setup with symbolic shape inference in legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - run_symbolic_shape_infer=True, - _use_deterministic_compute=True, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(loss.cpu()) - - # Compare losses - _test_helpers.assert_model_outputs(new_loss, expected_loss) - _test_helpers.assert_model_outputs(legacy_loss, expected_loss) - - -@pytest.mark.parametrize( - "test_input", - [ - ( - { - "distributed": {"enable_adasum": True}, - } - ) - ], -) -def testORTTrainerOptionsEnabledAdasumFlag(test_input): - """Test the enabled_adasum flag values when set enabled""" - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values.distributed.enable_adasum is True - - -@pytest.mark.parametrize( - "test_input", - [ - ( - { - "distributed": {"enable_adasum": False}, - } - ) - ], -) -def testORTTrainerOptionsDisabledAdasumFlag(test_input): - """Test the enabled_adasum flag values when set disabled""" - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values.distributed.enable_adasum is False - - -def testORTTrainerUnusedInput(): - class UnusedInputModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.mean(x) - - model = UnusedInputModel() - model_desc = {"inputs": [("x", [1]), ("y", [1])], "outputs": [("loss", [], True)]} - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config) - - # Run just one step to make sure there are no iobinding errors for the unused input. - try: - trainer.train_step(torch.FloatTensor([1.0]), torch.FloatTensor([1.0])) - except RuntimeError: - pytest.fail("RuntimeError doing train_step with an unused input.") - - -@pytest.mark.parametrize( - "debug_files", - [ - { - "model_after_graph_transforms_path": "transformed.onnx", - "model_with_gradient_graph_path": "transformed_grad.onnx", - "model_with_training_graph_path": "training.onnx", - "model_with_training_graph_after_optimization_path": "training_optimized.onnx", - }, - {"model_after_graph_transforms_path": "transformed.onnx", "model_with_training_graph_path": ""}, - ], -) -def testTrainingGraphExport(debug_files): - device = "cuda" - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - - with tempfile.TemporaryDirectory() as tempdir: - debug_paths = {} - for k, v in debug_files.items(): - debug_paths[k] = os.path.join(tempdir, v) - opts = orttrainer.ORTTrainerOptions({"device": {"id": device}, "debug": {"graph_save_paths": debug_paths}}) - optim_config = optim.AdamConfig() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - trainer.train_step(data, targets) - for k, v in debug_files.items(): - path = debug_paths[k] - if len(v) > 0: - assert os.path.isfile(path) - saved_graph = onnx.load(path).graph - if k == "model_with_training_graph_path": - assert any("AdamOptimizer" in n.op_type for n in saved_graph.node) - elif k == "model_with_gradient_graph_path": - assert any("Grad" in n.name for n in saved_graph.node) - elif k == "model_after_graph_transforms_path": - assert any("LayerNormalization" in n.op_type for n in saved_graph.node) - elif k == "model_with_training_graph_after_optimization_path": - assert any("FusedMatMul" in n.op_type for n in saved_graph.node) - # remove saved file - os.remove(path) - else: - assert not os.path.isfile(path) - - -def _adam_max_norm_clip_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.592951, - 10.067989, - 9.619152, - 9.245731, - 8.881137, - 8.578644, - 8.280573, - 8.063023, - 7.797933, - 7.486215, - 7.233806, - 7.011791, - ], - 14: [ - 10.584141, - 10.068119, - 9.581743, - 9.191472, - 8.880169, - 8.5352, - 8.311425, - 8.061202, - 7.773032, - 7.523009, - 7.258711, - 7.02805, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.592951, - 10.068722, - 9.620503, - 9.247791, - 8.883972, - 8.582286, - 8.285027, - 8.068308, - 7.803638, - 7.492318, - 7.240352, - 7.018665, - ], - 14: [ - 10.584141, - 10.068845, - 9.583107, - 9.193537, - 8.882966, - 8.538839, - 8.315872, - 8.066408, - 7.778978, - 7.529708, - 7.265849, - 7.035439, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.647908, - 10.144501, - 9.672352, - 9.306980, - 8.956026, - 8.602655, - 8.351079, - 8.088144, - 7.867220, - 7.564082, - 7.289846, - 7.073726, - ], - 14: [ - 10.697515, - 10.229034, - 9.765422, - 9.428294, - 9.080612, - 8.715208, - 8.459574, - 8.169073, - 7.940211, - 7.654147, - 7.390446, - 7.166227, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.647908, - 10.145191, - 9.673690, - 9.309031, - 8.959020, - 8.606632, - 8.355836, - 8.093478, - 7.873327, - 7.570731, - 7.296772, - 7.0809422, - ], - 14: [ - 10.697515, - 10.22967, - 9.766556, - 9.430037, - 9.083106, - 8.718601, - 8.463726, - 8.17396, - 7.945755, - 7.660188, - 7.396963, - 7.172944, - ], - }, - ), - ] - elif device_capability_major == 5: # M60 for CI machines (Python Packaging Pipeline) - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.618382, - 10.08292, - 9.603334, - 9.258133, - 8.917768, - 8.591574, - 8.318401, - 8.042292, - 7.783608, - 7.50226, - 7.236041, - 7.035602, - ], - 14: [ - 10.618382, - 10.08292, - 9.603334, - 9.258133, - 8.917768, - 8.591574, - 8.318401, - 8.042292, - 7.783608, - 7.50226, - 7.236041, - 7.035602, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.618382, - 10.083632, - 9.604639, - 9.260109, - 8.920504, - 8.595082, - 8.322799, - 8.047493, - 7.78929, - 7.508382, - 7.242587, - 7.042367, - ], - 14: [ - 10.618382, - 10.083632, - 9.604639, - 9.260109, - 8.920504, - 8.595082, - 8.322799, - 8.047493, - 7.78929, - 7.508382, - 7.242587, - 7.042367, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.68639, - 10.102986, - 9.647681, - 9.293091, - 8.958928, - 8.625297, - 8.351107, - 8.079577, - 7.840723, - 7.543044, - 7.284141, - 7.072688, - ], - 14: [ - 10.68639, - 10.102986, - 9.647681, - 9.293091, - 8.958928, - 8.625297, - 8.351107, - 8.079577, - 7.840723, - 7.543044, - 7.284141, - 7.072688, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.68639, - 10.103672, - 9.649025, - 9.295167, - 8.961777, - 8.629059, - 8.355571, - 8.084871, - 7.846589, - 7.549438, - 7.290722, - 7.079446, - ], - 14: [ - 10.697515, - 10.22967, - 9.766556, - 9.430037, - 9.083106, - 8.718601, - 8.463726, - 8.17396, - 7.945755, - 7.660188, - 7.396963, - 7.172944, - ], - }, - ), - ] - - -@pytest.mark.parametrize( - "seed,device,max_norm_clip,gradient_accumulation_steps,total_steps,expected_loss", _adam_max_norm_clip_data() -) -def testORTTrainerAdamMaxNormClip(seed, device, max_norm_clip, gradient_accumulation_steps, total_steps, expected_loss): - rtol = 1e-5 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.AdamConfig(lr=0.001, max_norm_clip=max_norm_clip) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu().item()) - - # Compare legacy vs experimental APIs - assert trainer._onnx_model is not None - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, rtol=rtol) - - -def _lamb_max_norm_clip_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.592951, - 10.487728, - 10.422251, - 10.350913, - 10.244248, - 10.213003, - 10.129222, - 10.095112, - 10.035983, - 9.974586, - 9.909771, - 9.874278, - ], - 14: [ - 10.584141, - 10.497192, - 10.389251, - 10.286045, - 10.231354, - 10.17018, - 10.066779, - 10.048138, - 9.958029, - 9.8908, - 9.82965, - 9.755484, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.592951, - 10.452503, - 10.349832, - 10.245314, - 10.106587, - 10.046009, - 9.934781, - 9.875164, - 9.792067, - 9.704592, - 9.617104, - 9.563070, - ], - 14: [ - 10.584141, - 10.461154, - 10.315399, - 10.178979, - 10.092329, - 9.999928, - 9.869949, - 9.824564, - 9.707565, - 9.61643, - 9.532847, - 9.439593, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.647908, - 10.566276, - 10.476154, - 10.406275, - 10.311079, - 10.240053, - 10.196469, - 10.113955, - 10.117376, - 10.013077, - 9.930301, - 9.893368, - ], - 14: [ - 10.697515, - 10.631279, - 10.528757, - 10.496689, - 10.411219, - 10.322109, - 10.297314, - 10.215549, - 10.149698, - 10.087336, - 10.010884, - 9.934544, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.647908, - 10.531957, - 10.405246, - 10.302971, - 10.176583, - 10.075583, - 10.005772, - 9.897825, - 9.875748, - 9.748932, - 9.642885, - 9.586762, - ], - 14: [ - 10.697515, - 10.596729, - 10.457815, - 10.393475, - 10.277581, - 10.158909, - 10.108126, - 10.000326, - 9.912526, - 9.826057, - 9.727899, - 9.633768, - ], - }, - ), - ] - elif device_capability_major == 5: # M60 for CI machines (Python Packaging Pipeline) - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.618382, - 10.50222, - 10.403347, - 10.35298, - 10.288447, - 10.237399, - 10.184225, - 10.089048, - 10.008952, - 9.972644, - 9.897674, - 9.84524, - ], - 14: [0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.618382, - 10.466732, - 10.330871, - 10.24715, - 10.150972, - 10.069127, - 9.98974, - 9.870169, - 9.763693, - 9.704323, - 9.605957, - 9.533117, - ], - 14: [1, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.68639, - 10.511692, - 10.447308, - 10.405255, - 10.334866, - 10.261473, - 10.169422, - 10.107138, - 10.069889, - 9.97798, - 9.928105, - 9.896435, - ], - 14: [2, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.68639, - 10.477489, - 10.376671, - 10.301725, - 10.200718, - 10.098477, - 9.97995, - 9.890104, - 9.828899, - 9.713555, - 9.639567, - 9.589856, - ], - 14: [3, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ] - - -@pytest.mark.parametrize( - "seed,device,max_norm_clip, gradient_accumulation_steps,total_steps,expected_loss", _lamb_max_norm_clip_data() -) -def testORTTrainerLambMaxNormClip(seed, device, max_norm_clip, gradient_accumulation_steps, total_steps, expected_loss): - rtol = 1e-3 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001, max_norm_clip=max_norm_clip) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu().item()) - - # Compare legacy vs experimental APIs - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, rtol=rtol) diff --git a/orttraining/orttraining/test/python/orttraining_test_transformers.py b/orttraining/orttraining/test/python/orttraining_test_transformers.py deleted file mode 100644 index dbaf4a293c466..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_transformers.py +++ /dev/null @@ -1,480 +0,0 @@ -import random -import unittest - -import numpy as np -import torch -from numpy.testing import assert_allclose -from orttraining_test_data_loader import BatchArgsOption, ids_tensor -from orttraining_test_utils import get_lr, run_test -from transformers import BertConfig, BertForPreTraining - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - - -class BertModelTest(unittest.TestCase): - class BertModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - scope=None, - device="cpu", - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.scope = scope - self.device = device - - # 1. superset of bert input/output descs - # see BertPreTrainedModel doc - self.input_ids_desc = IODescription( - "input_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size - ) - self.attention_mask_desc = IODescription( - "attention_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2 - ) - self.token_type_ids_desc = IODescription( - "token_type_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2 - ) - self.position_ids_desc = IODescription( - "position_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.max_position_embeddings - ) - self.head_mask_desc = IODescription( - "head_mask", [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2 - ) - self.inputs_embeds_desc = IODescription( - "inputs_embeds", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - - self.encoder_hidden_states_desc = IODescription( - "encoder_hidden_states", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - self.encoder_attention_mask_desc = IODescription( - "encoder_attention_mask", ["batch", "max_seq_len_in_batch"], torch.float32 - ) - - # see BertForPreTraining doc - self.masked_lm_labels_desc = IODescription( - "masked_lm_labels", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size - ) - self.next_sentence_label_desc = IODescription( - "next_sentence_label", - [ - "batch", - ], - torch.int64, - num_classes=2, - ) - - # outputs - self.loss_desc = IODescription( - "loss", - [ - 1, - ], - torch.float32, - ) - self.prediction_scores_desc = IODescription( - "prediction_scores", ["batch", "max_seq_len_in_batch", self.vocab_size], torch.float32 - ) - - self.seq_relationship_scores_desc = IODescription( - "seq_relationship_scores", ["batch", 2], torch.float32 - ) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32) - self.hidden_states_desc = IODescription( - "hidden_states", - [self.num_hidden_layers, "batch", "max_seq_len_in_batch", self.hidden_size], - torch.float32, - ) - self.attentions_desc = IODescription( - "attentions", - [ - self.num_hidden_layers, - "batch", - self.num_attention_heads, - "max_seq_len_in_batch", - "max_seq_len_in_batch", - ], - torch.float32, - ) - self.last_hidden_state_desc = IODescription( - "last_hidden_state", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - self.pooler_output_desc = IODescription("pooler_output", ["batch", self.hidden_size], torch.float32) - - def BertForPreTraining_descs(self): - return ModelDescription( - [ - self.input_ids_desc, - self.attention_mask_desc, - self.token_type_ids_desc, - self.masked_lm_labels_desc, - self.next_sentence_label_desc, - ], - # returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided - # hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states - [ - self.loss_desc, - self.prediction_scores_desc, - self.seq_relationship_scores_desc, - # hidden_states_desc, attentions_desc - ], - ) - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device) - - input_mask = None - if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device) - choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device) - - config = BertConfig( - vocab_size=self.vocab_size, - vocab_size_or_config_json_file=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - ) - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def create_and_check_bert_for_pretraining( - self, - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - option_use_internal_get_lr_this_step=[True], # noqa: B006 - option_use_internal_loss_scaler=[True], # noqa: B006 - ): - seed = 42 - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - model = BertForPreTraining(config=config) - model.eval() - loss, prediction_scores, seq_relationship_score = model( - input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - masked_lm_labels=token_labels, - next_sentence_label=sequence_labels, - ) - model_desc = ModelDescription( - [ - self.input_ids_desc, - self.attention_mask_desc, - self.token_type_ids_desc, - self.masked_lm_labels_desc, - self.next_sentence_label_desc, - ], - [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc], - ) - - from collections import namedtuple - - MyArgs = namedtuple( - "MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len" - ) - - dataset_len = 100 - epochs = 8 - max_steps = epochs * dataset_len - args = MyArgs( - local_rank=0, - world_size=1, - max_steps=max_steps, - learning_rate=0.00001, - warmup_proportion=0.01, - batch_size=13, - seq_len=7, - ) - - def get_lr_this_step(global_step): - return get_lr(args, global_step) - - loss_scaler = LossScaler("loss_scale_input_name", True, up_scale_window=2000) - - for fp16 in option_fp16: - for allreduce_post_accumulation in option_allreduce_post_accumulation: - for gradient_accumulation_steps in option_gradient_accumulation_steps: - for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step: - for use_internal_loss_scaler in option_use_internal_loss_scaler: - for split_batch in option_split_batch: - print("gradient_accumulation_steps:", gradient_accumulation_steps) - print("split_batch:", split_batch) - - seed = 42 - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - ( - old_api_loss_ort, - old_api_prediction_scores_ort, - old_api_seq_relationship_score_ort, - ) = run_test( - model, - model_desc, - self.device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - split_batch, - dataset_len, - epochs, - use_new_api=False, - ) - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - if use_internal_get_lr_this_step and use_internal_loss_scaler: - ( - new_api_loss_ort, - new_api_prediction_scores_ort, - new_api_seq_relationship_score_ort, - ) = run_test( - model, - model_desc, - self.device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - split_batch, - dataset_len, - epochs, - use_new_api=True, - ) - - assert_allclose(old_api_loss_ort, new_api_loss_ort) - assert_allclose(old_api_prediction_scores_ort, new_api_prediction_scores_ort) - assert_allclose( - old_api_seq_relationship_score_ort, new_api_seq_relationship_score_ort - ) - - def setUp(self): - self.model_tester = BertModelTest.BertModelTester(self) - - def test_for_pretraining_mixed_precision(self): - # It would be better to test both with/without mixed precision and allreduce_post_accumulation. - # However, stress test of all the 4 cases is not stable at least on the test machine. - # There we only test mixed precision and allreduce_post_accumulation because it is the most useful use cases. - option_fp16 = [True] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_mixed_precision_with_gradient_accumulation(self): - # It would be better to test both with/without mixed precision and allreduce_post_accumulation. - # However, stress test of all the 4 cases is not stable at least on the test machine. - # There we only test mixed precision and allreduce_post_accumulation because it is the most useful use cases. - option_fp16 = [True] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_all(self): - # This test is not stable because it create and run ORTSession multiple times. - # It occasionally gets seg fault at ~MemoryPattern() - # when releasing patterns_. In order not to block PR merging CI test, - # this test is broke into following individual tests. - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1, 8] - option_split_batch = [BatchArgsOption.List, BatchArgsOption.Dict, BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_list_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.List] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.Dict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_list_and_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_list_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.List] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.Dict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_list_and_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_utils.py b/orttraining/orttraining/test/python/orttraining_test_utils.py deleted file mode 100644 index 527cfb8a0ba7d..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_utils.py +++ /dev/null @@ -1,246 +0,0 @@ -import math - -import torch -from orttraining_test_data_loader import BatchArgsOption, create_ort_test_dataloader, split_batch - -from onnxruntime.capi.ort_trainer import IODescription, ORTTrainer -from onnxruntime.training import amp, optim, orttrainer -from onnxruntime.training.optim import _LRScheduler - - -def warmup_cosine(x, warmup=0.002): - if x < warmup: - return x / warmup - return 0.5 * (1.0 + torch.cos(math.pi * x)) - - -def warmup_constant(x, warmup=0.002): - if x < warmup: - return x / warmup - return 1.0 - - -def warmup_linear(x, warmup=0.002): - if x < warmup: - return x / warmup - return max((x - 1.0) / (warmup - 1.0), 0.0) - - -def warmup_poly(x, warmup=0.002, degree=0.5): - if x < warmup: - return x / warmup - return (1.0 - x) ** degree - - -SCHEDULES = { - "warmup_cosine": warmup_cosine, - "warmup_constant": warmup_constant, - "warmup_linear": warmup_linear, - "warmup_poly": warmup_poly, -} - - -def get_lr(args, training_steps, schedule="warmup_poly"): - if args.max_steps == -1: - return args.learning_rate - - schedule_fct = SCHEDULES[schedule] - return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion) - - -def map_optimizer_attributes(name): - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) - if no_decay: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - else: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - - -class WrapLRScheduler(_LRScheduler): - def __init__(self, get_lr_this_step): - super().__init__() - self.get_lr_this_step = get_lr_this_step - - def get_lr(self, train_step_info): - return [self.get_lr_this_step(train_step_info.optimization_step)] - - -def run_test( - model, - model_desc, - device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - batch_args_option, - dataset_len, - epochs, - use_new_api, -): - dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device) - - if use_new_api: - assert use_internal_loss_scaler, "new api should always use internal loss scaler" - - new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step) - - new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "device": {"id": device}, - "mixed_precision": {"enabled": fp16, "loss_scaler": new_api_loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": True}, - "distributed": {"allreduce_post_accumulation": True}, - "lr_scheduler": new_api_lr_scheduler, - } - ) - - param_optimizer = list(model.named_parameters()) - params = [ - { - "params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - { - "params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - ] - - vocab_size = 99 - new_model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - ), - ( - "next_sentence_label", - [ - "batch", - ], - ), - ], - "outputs": [ - ( - "loss", - [ - 1, - ], - True, - ), - ("prediction_scores", ["batch", "max_seq_len_in_batch", vocab_size]), - ("seq_relationship_scores", ["batch", 2]), - ], - } - - optim_config = optim.LambConfig(params=params, lr=2e-5) - model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options) - print("running with new frontend API") - else: - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes=map_optimizer_attributes, - learning_rate_description=IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device=device, - _enable_internal_postprocess=True, - gradient_accumulation_steps=gradient_accumulation_steps, - # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6 - world_rank=args.local_rank, - world_size=args.world_size, - use_mixed_precision=fp16, - allreduce_post_accumulation=allreduce_post_accumulation, - get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None, - loss_scaler=loss_scaler if use_internal_loss_scaler else None, - _opset_version=14, - _use_deterministic_compute=True, - ) - print("running with old frontend API") - - # training loop - eval_batch = None - if not use_new_api: - model.train() - for _epoch in range(epochs): - for step, batch in enumerate(dataloader): - if eval_batch is None: - eval_batch = batch - - if not use_internal_get_lr_this_step: - lr = get_lr_this_step(step) - learning_rate = torch.tensor([lr]) - - if not use_internal_loss_scaler and fp16: - loss_scale = torch.tensor([loss_scaler.loss_scale_]) - - if batch_args_option == BatchArgsOption.List: - if not use_internal_get_lr_this_step: - batch = [*batch, learning_rate] # noqa: PLW2901 - if not use_internal_loss_scaler and fp16: - batch = [*batch, loss_scale] # noqa: PLW2901 - outputs = model.train_step(*batch) - elif batch_args_option == BatchArgsOption.Dict: - args, kwargs = split_batch(batch, model_desc.inputs_, 0) - if not use_internal_get_lr_this_step: - kwargs["Learning_Rate"] = learning_rate - if not use_internal_loss_scaler and fp16: - kwargs[model.loss_scale_input_name] = loss_scale - outputs = model.train_step(*args, **kwargs) - else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs - args, kwargs = split_batch(batch, model_desc.inputs_, args_count) - if not use_internal_get_lr_this_step: - kwargs["Learning_Rate"] = learning_rate - if not use_internal_loss_scaler and fp16: - kwargs[model.loss_scale_input_name] = loss_scale - outputs = model.train_step(*args, **kwargs) - - # eval - if batch_args_option == BatchArgsOption.List: - outputs = model.eval_step(*batch) - elif batch_args_option == BatchArgsOption.Dict: - args, kwargs = split_batch(batch, model_desc.inputs_, 0) - outputs = model.eval_step(*args, **kwargs) - else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs - args, kwargs = split_batch(batch, model_desc.inputs_, args_count) - outputs = model.eval_step(*args, **kwargs) - - return (output.cpu().numpy() for output in outputs) diff --git a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py b/orttraining/orttraining/test/python/orttraining_transformer_trainer.py deleted file mode 100644 index bce726871bacf..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py +++ /dev/null @@ -1,357 +0,0 @@ -# adapted from Trainer.py of huggingface transformers - -import json -import logging -import os -import random -from typing import Callable, Dict, List, NamedTuple, Optional - -import numpy as np -import torch -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Dataset -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import SequentialSampler -from tqdm import tqdm, trange -from transformers.data.data_collator import DefaultDataCollator -from transformers.modeling_utils import PreTrainedModel -from transformers.training_args import TrainingArguments - -import onnxruntime -from onnxruntime.training import amp, optim, orttrainer - -try: - from torch.utils.tensorboard import SummaryWriter - - _has_tensorboard = True -except ImportError: - try: - from tensorboardX import SummaryWriter # noqa: F401 - - _has_tensorboard = True - except ImportError: - _has_tensorboard = False - - -def is_tensorboard_available(): - return _has_tensorboard - - -logger = logging.getLogger(__name__) - - -def set_seed(seed: int): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - -class EvalPrediction(NamedTuple): - predictions: np.ndarray - label_ids: np.ndarray - - -class PredictionOutput(NamedTuple): - predictions: np.ndarray - label_ids: Optional[np.ndarray] - metrics: Optional[Dict[str, float]] - - -class TrainOutput(NamedTuple): - global_step: int - training_loss: float - - -def get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps, base_lr): - def lr_lambda_linear(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) - - def lambda_lr_get_lr(current_global_step): - # LambdaLR increment self.last_epoch at evert sept() - return base_lr * lr_lambda_linear(current_global_step) - - return lambda_lr_get_lr - - -class ORTTransformerTrainer: - """ """ - - model: PreTrainedModel - args: TrainingArguments - train_dataset: Dataset - eval_dataset: Dataset - compute_metrics: Callable[[EvalPrediction], Dict] - - def __init__( - self, - model: PreTrainedModel, - model_desc: dict, - args: TrainingArguments, - train_dataset: Dataset, - eval_dataset: Dataset, - compute_metrics: Callable[[EvalPrediction], Dict], - world_size: Optional[int] = 1, - ): - """ """ - - self.model = model - self.model_desc = model_desc - self.args = args - self.world_size = world_size - self.data_collator = DefaultDataCollator() - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self.compute_metrics = compute_metrics - set_seed(self.args.seed) - # Create output directory if needed - if self.args.local_rank in [-1, 0]: - os.makedirs(self.args.output_dir, exist_ok=True) - - def get_train_dataloader(self) -> DataLoader: - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - train_sampler = ( - SequentialSampler(self.train_dataset) - if self.args.local_rank == -1 - else DistributedSampler(self.train_dataset) - ) - return DataLoader( - self.train_dataset, - batch_size=self.args.train_batch_size, - sampler=train_sampler, - collate_fn=self.data_collator.collate_batch, - ) - - def get_eval_dataloader(self) -> DataLoader: - return DataLoader( - self.eval_dataset, - batch_size=self.args.eval_batch_size, - shuffle=False, - collate_fn=self.data_collator.collate_batch, - ) - - def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: - # We use the same batch_size as for eval. - return DataLoader( - test_dataset, - batch_size=self.args.eval_batch_size, - shuffle=False, - collate_fn=self.data_collator.collate_batch, - ) - - def train(self): - """ - Main training entry point. - """ - train_dataloader = self.get_train_dataloader() - - if self.args.max_steps > 0: - t_total = self.args.max_steps - num_train_epochs = ( - self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 - ) - else: - t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) - num_train_epochs = self.args.num_train_epochs - - lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps / float(t_total)) - - loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None - device = self.args.device.type - - device = f"{device}:{self.args.device.index}" if self.args.device.index else f"{device}:0" - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": self.args.gradient_accumulation_steps}, - "device": {"id": device}, - "mixed_precision": {"enabled": self.args.fp16, "loss_scaler": loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": False}, - "distributed": { - # we are running single node multi gpu test. thus world_rank = local_rank - # and world_size = self.args.n_gpu - "world_rank": max(0, self.args.local_rank), - "world_size": int(self.world_size), - "local_rank": max(0, self.args.local_rank), - "allreduce_post_accumulation": True, - }, - "lr_scheduler": lr_scheduler, - } - ) - - param_optimizer = list(self.model.named_parameters()) - params = [ - { - "params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], - "weight_decay_mode": 1, - }, - { - "params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], - "weight_decay_mode": 1, - }, - ] - - optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) - self.model = orttrainer.ORTTrainer(self.model, self.model_desc, optim_config, options=options) - - # Train! - logger.info("***** Running training *****") - logger.info(" Num examples = %d", len(train_dataloader.dataset)) - logger.info(" Num Epochs = %d", num_train_epochs) - logger.info(" Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size) - logger.info( - " Total train batch size (w. parallel, distributed & accumulation) = %d", - self.args.train_batch_size - * self.args.gradient_accumulation_steps - * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), - ) - logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) - logger.info(" Total optimization steps = %d", t_total) - - global_step = 0 - epochs_trained = 0 - steps_trained_in_current_epoch = 0 - - tr_loss = 0.0 - logging_loss = 0.0 - train_iterator = trange( - epochs_trained, - int(num_train_epochs), - desc="Epoch", - disable=self.args.local_rank not in [-1, 0], - ) - - for _epoch in train_iterator: - epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) - for step, inputs in enumerate(epoch_iterator): - # Skip past any already trained steps if resuming training - if steps_trained_in_current_epoch > 0: - steps_trained_in_current_epoch -= 1 - continue - - tr_loss += self._training_step(self.model, inputs) - - if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( - len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator) - ): - global_step += 1 - - if self.args.local_rank in [-1, 0]: - if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or ( - global_step == 1 and self.args.logging_first_step - ): - logs = {} - if self.args.evaluate_during_training: - results = self.evaluate() - for key, value in results.items(): - eval_key = f"eval_{key}" - logs[eval_key] = value - - loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps - - logs["loss"] = loss_scalar - logging_loss = tr_loss - - epoch_iterator.write(json.dumps({**logs, **{"step": global_step}})) - - if self.args.max_steps > 0 and global_step > self.args.max_steps: - epoch_iterator.close() - break - if self.args.max_steps > 0 and global_step > self.args.max_steps: - train_iterator.close() - break - - logger.info("\n\nTraining completed. \n\n") - return TrainOutput(global_step, tr_loss / global_step) - - def _training_step(self, model, inputs: Dict[str, torch.Tensor]) -> float: - for k, v in inputs.items(): - inputs[k] = v.to(self.args.device) - - outputs = model.train_step(**inputs) - loss = outputs[0] # model outputs are always tuple in transformers (see doc) - - return loss.item() - - def save_model(self, output_dir: Optional[str] = None): - output_dir = output_dir if output_dir is not None else self.args.output_dir - os.makedirs(output_dir, exist_ok=True) - self.model.save_as_onnx(os.path.join(output_dir, "transformer.onnx")) - - def evaluate(self) -> Dict[str, float]: - """ - Run evaluation and return metrics. - - Returns: - A dict containing: - - the eval loss - - the potential metrics computed from the predictions - """ - eval_dataloader = self.get_eval_dataloader() - - output = self._prediction_loop(eval_dataloader, description="Evaluation") - return output.metrics - - def predict(self, test_dataset: Dataset) -> PredictionOutput: - """ - Run prediction and return predictions and potential metrics. - - Depending on the dataset and your use case, your test dataset may contain labels. - In that case, this method will also return metrics, like in evaluate(). - """ - test_dataloader = self.get_test_dataloader(test_dataset) - return self._prediction_loop(test_dataloader, description="Prediction") - - def _prediction_loop(self, dataloader: DataLoader, description: str) -> PredictionOutput: - """ - Prediction/evaluation loop, shared by `evaluate()` and `predict()`. - - Works both with or without labels. - """ - - logger.info("***** Running %s *****", description) - logger.info(" Num examples = %d", len(dataloader.dataset)) - logger.info(" Batch size = %d", dataloader.batch_size) - eval_losses: List[float] = [] - preds: np.ndarray = None - label_ids: np.ndarray = None - - for inputs in tqdm(dataloader, desc=description): - has_labels = any(inputs.get(k) is not None for k in ["labels", "masked_lm_labels"]) - - for k, v in inputs.items(): - inputs[k] = v.to(self.args.device) - - with torch.no_grad(): - outputs = self.model.eval_step(**inputs) - - if has_labels: - step_eval_loss, logits = outputs[:2] - eval_losses += [step_eval_loss.mean().item()] - else: - logits = outputs[0] - - if preds is None: - preds = logits.detach().cpu().numpy() - else: - preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) - if inputs.get("labels") is not None: - if label_ids is None: - label_ids = inputs["labels"].detach().cpu().numpy() - else: - label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) - - if self.compute_metrics is not None and preds is not None and label_ids is not None: - metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) - else: - metrics = {} - if len(eval_losses) > 0: - metrics["loss"] = np.mean(eval_losses) - - return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) diff --git a/orttraining/orttraining/test/python/utils_multiple_choice.py b/orttraining/orttraining/test/python/utils_multiple_choice.py deleted file mode 100644 index e0febaf2d6334..0000000000000 --- a/orttraining/orttraining/test/python/utils_multiple_choice.py +++ /dev/null @@ -1,269 +0,0 @@ -# adapted from run_multiple_choice.py of huggingface transformers -# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/utils_multiple_choice.py - -import csv -import glob # noqa: F401 -import json # noqa: F401 -import logging -import os -from dataclasses import dataclass -from enum import Enum -from typing import List, Optional - -import torch -import tqdm -from filelock import FileLock -from torch.utils.data.dataset import Dataset -from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available # noqa: F401 - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class InputExample: - """ - A single training/test example for multiple choice - - Args: - example_id: Unique id for the example. - question: string. The untokenized text of the second sequence (question). - contexts: list of str. The untokenized text of the first sequence (context of corresponding question). - endings: list of str. multiple choice's options. Its length must be equal to contexts' length. - label: (Optional) string. The label of the example. This should be - specified for train and dev examples, but not for test examples. - """ - - example_id: str - question: str - contexts: List[str] - endings: List[str] - label: Optional[str] - - -@dataclass(frozen=True) -class InputFeatures: - """ - A single set of features of data. - Property names are the same names as the corresponding inputs to a model. - """ - - example_id: str - input_ids: List[List[int]] - attention_mask: Optional[List[List[int]]] - token_type_ids: Optional[List[List[int]]] - label: Optional[int] - - -class Split(Enum): - train = "train" - dev = "dev" - test = "test" - - -class DataProcessor: - """Base class for data converters for multiple choice data sets.""" - - def get_train_examples(self, data_dir): - """Gets a collection of `InputExample`s for the train set.""" - raise NotImplementedError() - - def get_dev_examples(self, data_dir): - """Gets a collection of `InputExample`s for the dev set.""" - raise NotImplementedError() - - def get_test_examples(self, data_dir): - """Gets a collection of `InputExample`s for the test set.""" - raise NotImplementedError() - - def get_labels(self): - """Gets the list of labels for this data set.""" - raise NotImplementedError() - - -class MultipleChoiceDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach - soon. - """ - - features: List[InputFeatures] - - def __init__( - self, - data_dir: str, - tokenizer: PreTrainedTokenizer, - task: str, - processor: DataProcessor, - max_seq_length: Optional[int] = None, - overwrite_cache=False, - mode: Split = Split.train, - ): - cached_features_file = os.path.join( - data_dir, - "cached_{}_{}_{}_{}".format( - mode.value, - tokenizer.__class__.__name__, - str(max_seq_length), - task, - ), - ) - - # Make sure only the first process in distributed training processes the dataset, - # and the others will use the cache. - lock_path = cached_features_file + ".lock" - with FileLock(lock_path): - if os.path.exists(cached_features_file) and not overwrite_cache: - logger.info(f"Loading features from cached file {cached_features_file}") - self.features = torch.load(cached_features_file) - else: - logger.info(f"Creating features from dataset file at {data_dir}") - label_list = processor.get_labels() - if mode == Split.dev: - examples = processor.get_dev_examples(data_dir) - elif mode == Split.test: - examples = processor.get_test_examples(data_dir) - else: - examples = processor.get_train_examples(data_dir) - logger.info("Training examples: %s", len(examples)) - # TODO clean up all this to leverage built-in features of tokenizers - self.features = convert_examples_to_features( - examples, - label_list, - max_seq_length, - tokenizer, - pad_on_left=bool(tokenizer.padding_side == "left"), - pad_token=tokenizer.pad_token_id, - pad_token_segment_id=tokenizer.pad_token_type_id, - ) - logger.info("Saving features into cached file %s", cached_features_file) - torch.save(self.features, cached_features_file) - - def __len__(self): - return len(self.features) - - def __getitem__(self, i) -> InputFeatures: - return self.features[i] - - -class SwagProcessor(DataProcessor): - """Processor for the SWAG data set.""" - - def get_train_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} train") - return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train") - - def get_dev_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} dev") - return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev") - - def get_test_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} dev") - raise ValueError( - "For swag testing, the input file does not contain a label column. It can not be tested in current code" - "setting!" - ) - return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test") - - def get_labels(self): - """See base class.""" - return ["0", "1", "2", "3"] - - def _read_csv(self, input_file): - with open(input_file, encoding="utf-8") as f: - return list(csv.reader(f)) - - def _create_examples(self, lines: List[List[str]], type: str): - """Creates examples for the training and dev sets.""" - if type == "train" and lines[0][-1] != "label": - raise ValueError("For training, the input file must contain a label column.") - - examples = [ - InputExample( - example_id=line[2], - question=line[5], # in the swag dataset, the - # common beginning of each - # choice is stored in "sent2". - contexts=[line[4], line[4], line[4], line[4]], - endings=[line[7], line[8], line[9], line[10]], - label=line[11], - ) - for line in lines[1:] # we skip the line with the column names - ] - - return examples - - -def convert_examples_to_features( - examples: List[InputExample], - label_list: List[str], - max_length: int, - tokenizer: PreTrainedTokenizer, - pad_token_segment_id=0, - pad_on_left=False, - pad_token=0, - mask_padding_with_zero=True, -) -> List[InputFeatures]: - """ - Loads a data file into a list of `InputFeatures` - """ - - label_map = {label: i for i, label in enumerate(label_list)} - - features = [] - for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"): - if ex_index % 10000 == 0: - logger.info("Writing example %d of %d" % (ex_index, len(examples))) - choices_inputs = [] - for _ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)): - text_a = context - if example.question.find("_") != -1: - # this is for cloze question - text_b = example.question.replace("_", ending) - else: - text_b = example.question + " " + ending - - inputs = tokenizer.encode_plus( - text_a, - text_b, - add_special_tokens=True, - max_length=max_length, - pad_to_max_length=True, - return_overflowing_tokens=True, - ) - if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0: - logger.info( - "Attention! you are cropping tokens (swag task is ok). " - "If you are training ARC and RACE and you are poping question + options," - "you need to try to use a bigger max seq length!" - ) - - choices_inputs.append(inputs) - - label = label_map[example.label] - - input_ids = [x["input_ids"] for x in choices_inputs] - attention_mask = ( - [x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None - ) - token_type_ids = ( - [x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None - ) - - features.append( - InputFeatures( - example_id=example.example_id, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - label=label, - ) - ) - - for f in features[:2]: - logger.info("*** Example ***") - logger.info("feature: %s" % f) - - return features diff --git a/orttraining/pytorch_frontend_examples/mnist_training.py b/orttraining/pytorch_frontend_examples/mnist_training.py deleted file mode 100644 index dc9b3f654400c..0000000000000 --- a/orttraining/pytorch_frontend_examples/mnist_training.py +++ /dev/null @@ -1,200 +0,0 @@ -## This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py -## with modification to do training using onnxruntime as backend on cuda device. -## A private PyTorch build from https://aiinfra.visualstudio.com/Lotus/_git/pytorch (ORTTraining branch) is needed to run the demo. - -## Model testing is not complete. - -import argparse -import os - -import numpy as np # noqa: F401 -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim # noqa: F401 -from mpi4py import MPI -from torchvision import datasets, transforms - -from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer - -try: # noqa: SIM105 - from onnxruntime.capi._pybind_state import set_cuda_device_id -except ImportError: - pass - - -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return out - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -def train_with_trainer(args, trainer, device, train_loader, epoch): - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - learning_rate = torch.tensor([args.lr]) - loss = trainer.train_step(data, target, learning_rate) - - # Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss. - if batch_idx % args.log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss[0], - ) - ) - - -# TODO: comple this once ORT training can do evaluation. -def test_with_trainer(args, trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = F.log_softmax(trainer.eval_step(data, fetches=["probability"]), dim=1) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def mnist_model_description(): - input_desc = IODescription("input1", ["batch", 784], torch.float32) - label_desc = IODescription( - "label", - [ - "batch", - ], - torch.int64, - num_classes=10, - ) - loss_desc = IODescription("loss", [], torch.float32) - probability_desc = IODescription("probability", ["batch", 10], torch.float32) - return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc]) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - - args = parser.parse_args() - use_cuda = not args.no_cuda and torch.cuda.is_available() - - torch.manual_seed(args.seed) - - kwargs = {"num_workers": 0, "pin_memory": True} - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "../data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - **kwargs, - ) - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "../data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - **kwargs, - ) - - comm = MPI.COMM_WORLD - args.local_rank = ( - int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) if ("OMPI_COMM_WORLD_LOCAL_RANK" in os.environ) else 0 - ) - args.world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) if ("OMPI_COMM_WORLD_RANK" in os.environ) else 0 - args.world_size = comm.Get_size() - if use_cuda: - torch.cuda.set_device(args.local_rank) - device = torch.device("cuda", args.local_rank) - args.n_gpu = 1 - set_cuda_device_id(args.local_rank) - else: - device = torch.device("cpu") - - input_size = 784 - hidden_size = 500 - num_classes = 10 - model = NeuralNet(input_size, hidden_size, num_classes) - - model_desc = mnist_model_description() - # use log_interval as gradient accumulate steps - trainer = ORTTrainer( - model, - my_loss, - model_desc, - "SGDOptimizer", - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - 1, - args.world_rank, - args.world_size, - use_mixed_precision=False, - allreduce_post_accumulation=True, - ) - print("\nBuild ort model done.") - - for epoch in range(1, args.epochs + 1): - train_with_trainer(args, trainer, device, train_loader, epoch) - test_with_trainer(args, trainer, device, test_loader) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/mnist/mnist_original.onnx b/samples/python/training/orttrainer/mnist/mnist_original.onnx deleted file mode 100644 index 15931affb5ccf..0000000000000 Binary files a/samples/python/training/orttrainer/mnist/mnist_original.onnx and /dev/null differ diff --git a/samples/python/training/orttrainer/mnist/ort_mnist.py b/samples/python/training/orttrainer/mnist/ort_mnist.py deleted file mode 100644 index 8f8ccf373ccf6..0000000000000 --- a/samples/python/training/orttrainer/mnist/ort_mnist.py +++ /dev/null @@ -1,174 +0,0 @@ -# This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py -# with modification to do training using onnxruntime as backend on cuda device. - -import argparse -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision import datasets, transforms - -import onnxruntime -from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim - - -# Pytorch model -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, input1): - out = self.fc1(input1) - out = self.relu(out) - out = self.fc2(out) - return out - - -# ONNX Runtime training -def mnist_model_description(): - return { - "inputs": [("input1", ["batch", 784]), ("label", ["batch"])], - "outputs": [("loss", [], True), ("probability", ["batch", 10])], - } - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -# Helpers -def train(log_interval, trainer, device, train_loader, epoch, train_steps): - for batch_idx, (data, target) in enumerate(train_loader): - if batch_idx == train_steps: - break - - # Fetch data - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - # Train step - loss, prob = trainer.train_step(data, target) - - # Stats - if batch_idx % log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, batch_idx * len(data), len(train_loader.dataset), 100.0 * batch_idx / len(train_loader), loss - ) - ) - - -def test(trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - # Using fetches around without eval_step to not pass 'target' as input - trainer._train_step_info.fetches = ["probability"] - output = F.log_softmax(trainer.eval_step(data), dim=1) - trainer._train_step_info.fetches = [] - - # Stats - test_loss += F.nll_loss(output, target, reduction="sum").item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="ONNX Runtime MNIST Example") - parser.add_argument( - "--train-steps", - type=int, - default=-1, - metavar="N", - help="number of steps to train. Set -1 to run through whole dataset (default: -1)", - ) - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model state") - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - onnxruntime.set_seed(args.seed) - - # Data loader - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - ) - - if args.test_batch_size > 0: - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - ) - - # Modeling - model = NeuralNet(784, 500, 10) - model_desc = mnist_model_description() - optim_config = optim.SGDConfig(lr=args.lr) - opts = {"device": {"id": device}} - opts = ORTTrainerOptions(opts) - - trainer = ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - - # Train loop - for epoch in range(1, args.epochs + 1): - train(args.log_interval, trainer, device, train_loader, epoch, args.train_steps) - if args.test_batch_size > 0: - test(trainer, device, test_loader) - - # Save model - if args.save_path: - torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt")) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/mnist/pytorch_mnist.py b/samples/python/training/orttrainer/mnist/pytorch_mnist.py deleted file mode 100644 index 2e451d85f62e8..0000000000000 --- a/samples/python/training/orttrainer/mnist/pytorch_mnist.py +++ /dev/null @@ -1,157 +0,0 @@ -import argparse -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms - - -# Pytorch model -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, input1): - out = self.fc1(input1) - out = self.relu(out) - out = self.fc2(out) - return out - - -def my_loss(x, target, is_train=True): - if is_train: - return F.nll_loss(F.log_softmax(x, dim=1), target) - else: - return F.nll_loss(F.log_softmax(x, dim=1), target, reduction="sum") - - -# Helpers -def train(args, model, device, train_loader, optimizer, epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - if batch_idx == args.train_steps: - break - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - optimizer.zero_grad() - output = model(data) - loss = my_loss(output, target) - loss.backward() - optimizer.step() - if batch_idx % args.log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - - -def test(model, device, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = model(data) - # Stats - test_loss += my_loss(output, target, False).item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--train-steps", - type=int, - default=-1, - metavar="N", - help="number of steps to train. Set -1 to run through whole dataset (default: -1)", - ) - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model") - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - - # Data loader - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - ) - - if args.test_batch_size > 0: - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - ) - - # Modeling - model = NeuralNet(784, 500, 10).to(device) - optimizer = optim.SGD(model.parameters(), lr=args.lr) - - # Train loop - for epoch in range(1, args.epochs + 1): - train(args, model, device, train_loader, optimizer, epoch) - if args.test_batch_size > 0: - test(model, device, test_loader) - - # Save model - if args.save_path: - torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt")) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/pytorch_transformer/README.md b/samples/python/training/orttrainer/pytorch_transformer/README.md deleted file mode 100644 index cda8cba6ca0ad..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# TransformerModel example - -This example was adapted from Pytorch's [Sequence-to-Sequence Modeling with nn.Transformer and TorchText](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) tutorial - -## Requirements - -* PyTorch 1.6+ -* TorchText 0.6+ -* ONNX Runtime 1.5+ - -## Running PyTorch version - -```bash -python pt_train.py -``` - -## Running ONNX Runtime version - -```bash -python ort_train.py -``` - -## Optional arguments - -| Argument | Description | Default | -| :---------------- | :-----------------------------------------------------: | --------: | -| --batch-size | input batch size for training | 20 | -| --test-batch-size | input batch size for testing | 20 | -| --epochs | number of epochs to train | 2 | -| --lr | learning rate | 0.001 | -| --no-cuda | disables CUDA training | False | -| --seed | random seed | 1 | -| --log-interval | how many batches to wait before logging training status | 200 | diff --git a/samples/python/training/orttrainer/pytorch_transformer/ort_train.py b/samples/python/training/orttrainer/pytorch_transformer/ort_train.py deleted file mode 100644 index 551e878cc9035..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/ort_train.py +++ /dev/null @@ -1,89 +0,0 @@ -import argparse - -import torch -from ort_utils import my_loss, transformer_model_description_dynamic_axes -from pt_model import TransformerModel -from utils import get_batch, prepare_data - -import onnxruntime - - -def train(trainer, data_source, device, epoch, args, bptt=35): - total_loss = 0.0 - for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): - data, targets = get_batch(data_source, i) - - loss, pred = trainer.train_step(data, targets) - total_loss += loss.item() - if batch % args.log_interval == 0 and batch > 0: - cur_loss = total_loss / args.log_interval - print( - "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format( - epoch, batch, len(data_source) // bptt, cur_loss - ) - ) - total_loss = 0 - - -def evaluate(trainer, data_source, bptt=35): - total_loss = 0.0 - with torch.no_grad(): - for i in range(0, data_source.size(0) - 1, bptt): - data, targets = get_batch(data_source, i) - loss, pred = trainer.eval_step(data, targets) - total_loss += len(data) * loss.item() - return total_loss / (len(data_source) - 1) - - -if __name__ == "__main__": - # Training settings - parser = argparse.ArgumentParser(description="PyTorch TransformerModel example") - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)" - ) - parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)") - parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=200, - metavar="N", - help="how many batches to wait before logging training status (default: 200)", - ) - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - onnxruntime.set_seed(args.seed) - - # Model - optim_config = onnxruntime.training.optim.SGDConfig(lr=args.lr) - model_desc = transformer_model_description_dynamic_axes() - model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - - # Preparing data - train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size) - trainer = onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss) - - # Train - for epoch in range(1, args.epochs + 1): - train(trainer, train_data, device, epoch, args) - val_loss = evaluate(trainer, val_data) - print("-" * 89) - print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ") - print("-" * 89) - - # Evaluate - test_loss = evaluate(trainer, test_data) - print("=" * 89) - print(f"| End of training | test loss {test_loss:5.2f}") - print("=" * 89) diff --git a/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py b/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py deleted file mode 100644 index 73992f5596f5f..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription -from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription - - -def my_loss(x, target): - x = x.view(-1, 28785) - return torch.nn.CrossEntropyLoss()(x, target) - - -def transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - model_desc = { - "inputs": [("input1", [bptt, batch_size]), ("label", [bptt * batch_size])], - "outputs": [("loss", [], True), ("predictions", [bptt, batch_size, ntokens])], - } - return model_desc - - -def transformer_model_description_dynamic_axes(ntokens=28785): - model_desc = { - "inputs": [("input1", ["bptt", "batch_size"]), ("label", ["bptt_x_batch_size"])], - "outputs": [("loss", [], True), ("predictions", ["bptt", "batch_size", ntokens])], - } - return model_desc - - -def legacy_transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - input_desc = Legacy_IODescription("input1", [bptt, batch_size]) - label_desc = Legacy_IODescription("label", [bptt * batch_size]) - loss_desc = Legacy_IODescription("loss", []) - predictions_desc = Legacy_IODescription("predictions", [bptt, batch_size, ntokens]) - return ( - Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), - Legacy_IODescription("__learning_rate", [1]), - ) - - -def legacy_transformer_model_description_dynamic_axes(ntokens=28785): - input_desc = Legacy_IODescription("input1", ["bptt", "batch_size"]) - label_desc = Legacy_IODescription("label", ["bptt_x_batch_size"]) - loss_desc = Legacy_IODescription("loss", []) - predictions_desc = Legacy_IODescription("predictions", ["bptt", "batch_size", ntokens]) - return ( - Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), - Legacy_IODescription("__learning_rate", [1]), - ) diff --git a/samples/python/training/orttrainer/pytorch_transformer/pt_model.py b/samples/python/training/orttrainer/pytorch_transformer/pt_model.py deleted file mode 100644 index 4f2e03192c6cf..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/pt_model.py +++ /dev/null @@ -1,62 +0,0 @@ -import math - -import torch -import torch.nn as nn - - -class TransformerModel(nn.Module): - def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): - super().__init__() - from torch.nn import TransformerEncoder, TransformerEncoderLayer - - self.model_type = "Transformer" - self.input1_mask = None - self.pos_encoder = PositionalEncoding(ninp, dropout) - encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) - self.encoder = nn.Embedding(ntoken, ninp) - self.ninp = ninp - self.decoder = nn.Linear(ninp, ntoken) - - self.init_weights() - - def _generate_square_subsequent_mask(self, sz): - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0) - return mask - - def init_weights(self): - initrange = 0.1 - self.encoder.weight.data.uniform_(-initrange, initrange) - self.decoder.bias.data.zero_() - self.decoder.weight.data.uniform_(-initrange, initrange) - - def forward(self, input1): - if self.input1_mask is None or self.input1_mask.size(0) != input1.size(0): - device = input1.device - mask = self._generate_square_subsequent_mask(input1.size(0)).to(device) - self.input1_mask = mask - - input1 = self.encoder(input1) * math.sqrt(self.ninp) - input1 = self.pos_encoder(input1) - output = self.transformer_encoder(input1, self.input1_mask) - output = self.decoder(output) - return output - - -class PositionalEncoding(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer("pe", pe) - - def forward(self, x): - x = x + self.pe[: x.size(0), :] - return self.dropout(x) diff --git a/samples/python/training/orttrainer/pytorch_transformer/pt_train.py b/samples/python/training/orttrainer/pytorch_transformer/pt_train.py deleted file mode 100644 index a197fb50357e9..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/pt_train.py +++ /dev/null @@ -1,94 +0,0 @@ -import argparse - -import torch -import torch.nn as nn -from pt_model import TransformerModel -from utils import get_batch, prepare_data - - -def train(model, data_source, device, epoch, args, bptt=35): - total_loss = 0.0 - model.train() - for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): - data, targets = get_batch(data_source, i) - - optimizer.zero_grad() - output = model(data) - loss = criterion(output.view(-1, 28785), targets) - loss.backward() - optimizer.step() - - total_loss += loss.item() - if batch % args.log_interval == 0 and batch > 0: - cur_loss = total_loss / args.log_interval - print( - "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format( - epoch, batch, len(data_source) // bptt, cur_loss - ) - ) - total_loss = 0 - - -def evaluate(model, data_source, criterion, bptt=35): - total_loss = 0.0 - model.eval() - with torch.no_grad(): - for i in range(0, data_source.size(0) - 1, bptt): - data, targets = get_batch(data_source, i) - output = model(data) - output_flat = output.view(-1, 28785) - total_loss += len(data) * criterion(output_flat, targets).item() - return total_loss / (len(data_source) - 1) - - -if __name__ == "__main__": - # Training settings - parser = argparse.ArgumentParser(description="PyTorch TransformerModel example") - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)" - ) - parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)") - parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=200, - metavar="N", - help="how many batches to wait before logging training status (default: 200)", - ) - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - - # Model - criterion = nn.CrossEntropyLoss() - lr = 0.001 - model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - optimizer = torch.optim.SGD(model.parameters(), lr=lr) - - # Preparing data - train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size) - - # Train - for epoch in range(1, args.epochs + 1): - train(model, train_data, device, epoch, args) - val_loss = evaluate(model, val_data, criterion) - print("-" * 89) - print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ") - print("-" * 89) - - # Evaluate - test_loss = evaluate(model, test_data, criterion) - print("=" * 89) - print(f"| End of training | test loss {test_loss:5.2f}") - print("=" * 89) diff --git a/samples/python/training/orttrainer/pytorch_transformer/utils.py b/samples/python/training/orttrainer/pytorch_transformer/utils.py deleted file mode 100644 index 3be8b6cf3f420..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/utils.py +++ /dev/null @@ -1,59 +0,0 @@ -import os - -import torch -from torchtext.data.utils import get_tokenizer -from torchtext.utils import download_from_url, extract_archive -from torchtext.vocab import build_vocab_from_iterator - - -def batchify(data, bsz, device): - # Divide the dataset into bsz parts. - nbatch = data.size(0) // bsz - # Trim off any extra elements that wouldn't cleanly fit (remainders). - data = data.narrow(0, 0, nbatch * bsz) - # Evenly divide the data across the bsz batches. - data = data.view(bsz, -1).t().contiguous() - return data.to(device) - - -def get_batch(source, i, bptt=35): - seq_len = min(bptt, len(source) - 1 - i) - data = source[i : i + seq_len] - target = source[i + 1 : i + 1 + seq_len].view(-1) - return data, target - - -def prepare_data(device="cpu", train_batch_size=20, eval_batch_size=20, data_dir=None): - url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip" - - download_path = ".data_wikitext_2_v1" - extract_path = None - if data_dir: - download_path = os.path.join(data_dir, "download") - os.makedirs(download_path, exist_ok=True) - download_path = os.path.join(download_path, "wikitext-2-v1.zip") - - extract_path = os.path.join(data_dir, "extracted") - os.makedirs(extract_path, exist_ok=True) - - test_filepath, valid_filepath, train_filepath = extract_archive( - download_from_url(url, root=download_path), to_path=extract_path - ) - tokenizer = get_tokenizer("basic_english") - vocab = build_vocab_from_iterator(map(tokenizer, iter(open(train_filepath, encoding="utf8")))) # noqa: SIM115 - - def data_process(raw_text_iter): - data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter] - return torch.cat(tuple(filter(lambda t: t.numel() > 0, data))) - - train_data = data_process(iter(open(train_filepath, encoding="utf8"))) # noqa: SIM115 - val_data = data_process(iter(open(valid_filepath, encoding="utf8"))) # noqa: SIM115 - test_data = data_process(iter(open(test_filepath, encoding="utf8"))) # noqa: SIM115 - - device = torch.device(device) - - train_data = batchify(train_data, train_batch_size, device) - val_data = batchify(val_data, eval_batch_size, device) - test_data = batchify(test_data, eval_batch_size, device) - - return train_data, val_data, test_data diff --git a/setup.py b/setup.py index 1c04433c9a7ca..798c8c4b2895b 100644 --- a/setup.py +++ b/setup.py @@ -196,7 +196,7 @@ def run(self): "libcublasLt.so.11", "libcublasLt.so.12", "libcudart.so.11.0", - "libcudart.so.12.0", + "libcudart.so.12", "libcudnn.so.8", "libcufft.so.10", "libcufft.so.11", @@ -398,7 +398,6 @@ def finalize_options(self): "onnxruntime", "onnxruntime.backend", "onnxruntime.capi", - "onnxruntime.capi.training", "onnxruntime.datasets", "onnxruntime.tools", "onnxruntime.tools.mobile_helpers", diff --git a/tools/android_custom_build/Dockerfile b/tools/android_custom_build/Dockerfile index 66b6a36e5a8c0..754a6633b0c62 100644 --- a/tools/android_custom_build/Dockerfile +++ b/tools/android_custom_build/Dockerfile @@ -55,7 +55,7 @@ WORKDIR /workspace # install Android SDK and tools ENV ANDROID_HOME=~/android-sdk -ENV NDK_VERSION=26.0.10792818 +ENV NDK_VERSION=26.1.10909125 ENV ANDROID_NDK_HOME=${ANDROID_HOME}/ndk/${NDK_VERSION} RUN aria2c -q -d /tmp -o cmdline-tools.zip \ diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index e0559419ef8c7..6bd3e2533c045 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1171,9 +1171,9 @@ def generate_build_tree( "-Donnxruntime_USE_OPENVINO_AUTO=" + ("ON" if args.use_openvino.startswith("AUTO") else "OFF"), ] - # TensorRT and OpenVINO providers currently only support + # VitisAI and OpenVINO providers currently only support # full_protobuf option. - if args.use_full_protobuf or args.use_tensorrt or args.use_openvino or args.use_vitisai or args.gen_doc: + if args.use_full_protobuf or args.use_openvino or args.use_vitisai or args.gen_doc: cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"] if args.use_tvm and args.llvm_path is not None: diff --git a/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml b/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml index 56f6bd56eeed7..e664cf69dec76 100644 --- a/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml +++ b/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml @@ -67,9 +67,9 @@ steps: EM_DIR: '$(Build.SourcesDirectory)/cmake/external/emsdk/upstream/emscripten' - ${{if eq(parameters.WithCache, false)}}: - - task: PythonScript@0 - displayName: '${{parameters.DisplayName}}' - inputs: - scriptPath: '$(Build.SourcesDirectory)/tools/ci_build/build.py' - arguments: ${{parameters.Arguments}} - workingDirectory: '$(Build.BinariesDirectory)' + - script: | + set -e -x + source $(Build.SourcesDirectory)/cmake/external/emsdk/emsdk_env.sh + cd '$(Build.BinariesDirectory)' + python3 '$(Build.SourcesDirectory)/tools/ci_build/build.py' ${{parameters.Arguments}} + displayName: ${{parameters.DisplayName}} diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 6fc076441108c..f2deb2041e06e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.117 + version: 1.0.118 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.117 + version: 1.0.118 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml new file mode 100644 index 0000000000000..4573c56963e34 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -0,0 +1,51 @@ +parameters: + - name: DownloadCUDA + type: boolean + default: false + - name: DownloadTRT + type: boolean + default: false + - name: CudaVersion + type: string + default: '11.8' + values: + - 11.8 + - 12.2 + +steps: + + - ${{ if eq(parameters.DownloadCUDA, true) }}: + - powershell: | + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.CudaVersion }} $(Agent.TempDirectory) + displayName: 'Download CUDA SDK v${{ parameters.CudaVersion }}' + - powershell: | + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}\bin;$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}\extras\CUPTI\lib64" + displayName: 'Append CUDA SDK Directory to PATH' + - task: CmdLine@2 + inputs: + script: | + echo %PATH% + displayName: 'Print PATH' + + - ${{ if eq(parameters.DownloadTRT, true) }}: + - ${{ if eq(parameters.CudaVersion, '11.8') }}: + - powershell: | + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8 $(Agent.TempDirectory) + displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8' + - powershell: | + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib" + displayName: 'Append CUDA SDK Directory to PATH' + + - ${{ if eq(parameters.CudaVersion, '12.2') }}: + - powershell: | + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0 $(Agent.TempDirectory) + displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0' + - powershell: | + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0\lib" + displayName: 'Append CUDA SDK Directory to PATH' + + - task: CmdLine@2 + inputs: + script: | + echo %PATH% + displayName: 'Print PATH' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index f81b1ddc8b93b..852d688b2dbb1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -90,13 +90,20 @@ jobs: arguments: --new_dir $(Build.BinariesDirectory)/deps workingDirectory: $(Build.BinariesDirectory) - - script: | - set -ex - cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 3.1.44 ccache-git-emscripten-64bit - ./emsdk activate 3.1.44 ccache-git-emscripten-64bit - displayName: 'emsdk install and activate ccache for emscripten' - condition: eq('${{ parameters.WithCache }}', 'true') + - ${{if eq(parameters.WithCache, true)}}: + - script: | + set -ex + cd '$(Build.SourcesDirectory)/cmake/external/emsdk' + ./emsdk install 3.1.44 ccache-git-emscripten-64bit + ./emsdk activate 3.1.44 ccache-git-emscripten-64bit + displayName: 'emsdk install and activate ccache for emscripten' + - ${{if eq(parameters.WithCache, false)}}: + - script: | + set -ex + cd '$(Build.SourcesDirectory)/cmake/external/emsdk' + ./emsdk install 3.1.44 + ./emsdk activate 3.1.44 + displayName: 'emsdk install and activate ccache for emscripten' - template: build-linux-wasm-step.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml b/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml index 8cc7f63a193cc..b8dba89b0b899 100644 --- a/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml @@ -3,7 +3,7 @@ parameters: - name: AndroidNdkVersion type: string - default: "26.0.10792818" # LTS version + default: "26.1.10909125" # LTS version steps: - bash: | diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index c649883ea0d8b..9982b36509b68 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -65,7 +65,6 @@ stages: clean: all steps: - checkout: self - fetchDepth: 1 submodules: false - script: | git submodule sync -- cmake/external/onnx diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh index 3bca6413100a2..da8a45e00cc90 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh @@ -19,7 +19,9 @@ fi export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" +# This may install PyTorch, which will be overrided by the PyTorch local build below. /opt/python/cp39-cp39/bin/python3.9 -m pip install transformers + # beartype is installed here so that onnxscript installation step won't # install a version PyTorch doesn't like. Once beartype fixes this problem. # We can remove this line. diff --git a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py index 113b5398f3981..9eccb7c36455f 100644 --- a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py +++ b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py @@ -10,9 +10,8 @@ import sys import onnx -from onnx import shape_inference -from ..onnx_model_utils import get_opsets_imported +from ..onnx_model_utils import ModelProtoWithShapeInfo, get_opsets_imported from ..reduced_build_config_parser import parse_config cpp_to_tensorproto_type = { @@ -265,15 +264,13 @@ def run_check(model_path: pathlib.Path, mobile_pkg_build_config: pathlib.Path, l ) model_file = model_path.resolve(strict=True) - model = onnx.load(str(model_file)) # we need to run shape inferencing to populate that type info for node outputs. # we will get warnings if the model uses ORT contrib ops (ONNX does not have shape inferencing for those), # and shape inferencing will be lost downstream of those. # TODO: add support for checking ORT format model as it will have full type/shape info for all nodes - model_with_type_info = shape_inference.infer_shapes(model) - - return run_check_with_model(model_with_type_info, mobile_pkg_build_config, logger) + model_wrapper = ModelProtoWithShapeInfo(model_file) + return run_check_with_model(model_wrapper.model_with_shape_info, mobile_pkg_build_config, logger) def main(): diff --git a/tools/python/util/mobile_helpers/usability_checker.py b/tools/python/util/mobile_helpers/usability_checker.py index f8b0bfe707ead..dcb3451a5e0fa 100644 --- a/tools/python/util/mobile_helpers/usability_checker.py +++ b/tools/python/util/mobile_helpers/usability_checker.py @@ -13,6 +13,7 @@ import onnx from ..onnx_model_utils import ( + ModelProtoWithShapeInfo, get_producer_consumer_maps, is_fixed_size_tensor, iterate_graph_per_graph_func, @@ -464,9 +465,9 @@ def check_shapes(graph: onnx.GraphProto, logger: Optional[logging.Logger] = None return dynamic_inputs, num_dynamic_values -def checker(model_path, logger: logging.Logger): - model = onnx.load(model_path) - model_with_shape_info = onnx.shape_inference.infer_shapes(model) +def checker(model_path: pathlib.Path, logger: logging.Logger): + model_with_shape_info_wrapper = ModelProtoWithShapeInfo(model_path) + model_with_shape_info = model_with_shape_info_wrapper.model_with_shape_info # create lookup map for efficiency value_to_shape = {} @@ -541,10 +542,10 @@ def analyze_model(model_path: pathlib.Path, skip_optimize: bool = False, logger: with tempfile.TemporaryDirectory() as tmp: if not skip_optimize: tmp_path = pathlib.Path(tmp) / model_path.name - optimize_model(model_path, tmp_path) + optimize_model(model_path, tmp_path, use_external_initializers=True) model_path = tmp_path - try_eps = checker(str(model_path.resolve(strict=True)), logger) + try_eps = checker(model_path.resolve(strict=True), logger) return try_eps diff --git a/tools/python/util/onnx_model_utils.py b/tools/python/util/onnx_model_utils.py index e662d1623f8bd..5c970430a3a82 100644 --- a/tools/python/util/onnx_model_utils.py +++ b/tools/python/util/onnx_model_utils.py @@ -95,6 +95,7 @@ def optimize_model( output_path: pathlib.Path, level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC, log_level: int = 3, + use_external_initializers: bool = False, ): """ Optimize an ONNX model using ONNX Runtime to the specified level @@ -103,12 +104,25 @@ def optimize_model( :param level: onnxruntime.GraphOptimizationLevel to use. Default is ORT_ENABLE_BASIC. :param log_level: Log level. Defaults to Error (3) so we don't get output about unused initializers being removed. Warning (2) or Info (1) may be desirable in some scenarios. + :param use_external_initializers: Set flag to write initializers to an external file. Required if model > 2GB. + Requires onnxruntime 1.17+ """ so = ort.SessionOptions() so.optimized_model_filepath = str(output_path.resolve()) so.graph_optimization_level = level so.log_severity_level = log_level + # save using external initializers so models > 2 GB are handled + if use_external_initializers: + major, minor, rest = ort.__version__.split(".", 3) + if (int(major), int(minor)) >= (1, 17): + so.add_session_config_entry("session.optimized_model_external_initializers_file_name", "external_data.pb") + else: + raise ValueError( + "ONNX Runtime 1.17 or higher required to save initializers as external data when optimizing model. " + f"Current ONNX Runtime version is {ort.__version__}" + ) + # create session to optimize. this will write the updated model to output_path _ = ort.InferenceSession(str(model_path.resolve(strict=True)), so, providers=["CPUExecutionProvider"]) @@ -366,3 +380,34 @@ def get_optimization_level(level): return ort.GraphOptimizationLevel.ORT_ENABLE_ALL raise ValueError("Invalid optimization level of " + level) + + +class ModelProtoWithShapeInfo: + """ + Class to load an ONNX model and run shape inferencing on it to populate the ValueInfo. + The model_with_shape_info property will contain the updated model. + If the model is > 2GB and uses external data a temporary file is required to run shape inferencing successfully. + This helper class handles automatic removal of the temporary file. + """ + + def __init__(self, model_path: pathlib.Path): + """ + :param model_path: Path to ONNX model to load and run shape inferencing on. + """ + + self.model_path = model_path + + model = onnx.load(str(model_path)) + self.model_with_shape_info = onnx.shape_inference.infer_shapes(model, strict_mode=True) + + # ONNX has a silent failure from the call to infer_shapes when the model is > 2GB. + # We detect that by checking the nodes in the returned model. + self._tmp_model_path = None + if len(model.graph.node) > 0 and len(self.model_with_shape_info.graph.node) == 0: + self._tmp_model_path = pathlib.Path(model_path).with_suffix(".temp_with_shapeinf.onnx") + onnx.shape_inference.infer_shapes_path(str(model_path), str(self._tmp_model_path), strict_mode=True) + self.model_with_shape_info = onnx.load(str(self._tmp_model_path)) + + def __del__(self): + if self._tmp_model_path: + self._tmp_model_path.unlink(missing_ok=True)