diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index e82219a0aff64..5796db03fed7c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -114,9 +114,7 @@ option(onnxruntime_ENABLE_LTO "Enable link time optimization" OFF) option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF) option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF) option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) - -#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But Tensort always required the full version of protobuf. -cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON) +option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir") option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index a9a78668b4810..345ef2b504aa4 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -339,9 +339,6 @@ configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py) if (onnxruntime_ENABLE_TRAINING) - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/python/deprecated/*.py" - ) file(GLOB onnxruntime_python_root_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/*.py" ) @@ -419,10 +416,6 @@ if (onnxruntime_ENABLE_TRAINING) "${ORTTRAINING_SOURCE_DIR}/python/training/onnxblock/optim/*" ) endif() -else() - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/python/training/*.py" - ) endif() if (onnxruntime_BUILD_UNIT_TESTS) @@ -443,6 +436,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx" ) + file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx" + ) endif() file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS @@ -556,6 +552,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/whisper COMMAND ${CMAKE_COMMAND} -E make_directory $/eager_test + COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models/conformer COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_ROOT}/__init__.py $/onnxruntime/ @@ -577,9 +574,6 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py $/onnxruntime/capi/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy $ $/onnxruntime/capi/ @@ -711,6 +705,9 @@ if (onnxruntime_BUILD_UNIT_TESTS) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_testdata_whisper} $/transformers/test_data/models/whisper/ + COMMAND ${CMAKE_COMMAND} -E copy + 
${onnxruntime_python_transformers_testdata_conformer} + $/transformers/test_data/models/conformer/ ) endif() @@ -750,9 +747,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/data/ COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/hooks/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_root_srcs} $/onnxruntime/training/ diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index bdb0230a8ebd0..a52e941b235b4 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -906,7 +906,7 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") + set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") endif() diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9c31978c66486..c73f978bdf404 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2385,7 +2385,7 @@ This version of the operator has been available since version 1 of the 'com.micr Group Query Self/Cross Attention. - Supports different number of heads for q and kv. + Supports different number of heads for q and kv. Only supports causal or local attention. #### Version @@ -2396,6 +2396,8 @@ This version of the operator has been available since version 1 of the 'com.micr
kv_num_heads : int (required)
Number of attention heads for k and v
+
local_window_size : int
+
left_window_size for local attention (like Mistral). Default value is -1, meaning unused.
num_heads : int (required)
Number of attention heads for q
scale : float
@@ -2647,8 +2649,8 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T1 : tensor(float), tensor(float16)
-
Constrain input and output types to float/half_float tensors.
+
T1 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float/half_float/brain_float tensors.
T2 : tensor(uint8)
Constrain quantized weight types to uint8.
@@ -5021,7 +5023,7 @@ This version of the operator has been available since version 1 of the 'com.micr
input : T
-
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
@@ -5034,7 +5036,7 @@ This version of the operator has been available since version 1 of the 'com.micr
output : T
-
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
tensor with same shape as input.
#### Type Constraints diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 26b5ebbdbec36..16df788c284ee 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -840,7 +840,7 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| -|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index 443710884743a..0c0af16d4e20c 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -399,6 +399,15 @@ struct TensorArray : public ArgBase { using Variadic = TensorArray; +/* +Note: +OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and ORT core. +The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ORT, so: +1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release it, since there is no virtual destructor in the hierarchy. +2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in the form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp, + hence memory could still be recycled properly. +Further, OrtCustomOp is a C struct bearing no v-table, so offspring structs are, by design, required to have zero virtual functions to maintain cast safety. +*/ struct OrtLiteCustomOp : public OrtCustomOp { using ConstOptionalFloatTensor = std::optional&>; using OptionalFloatTensor = std::optional>; @@ -774,10 +783,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtLiteCustomOp(const char* op_name, const char* execution_provider, - int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), - execution_provider_(execution_provider), - start_ver_(start_ver), - end_ver_(end_ver) { + ShapeInferFn shape_infer_fn, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + shape_infer_fn_(shape_infer_fn), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -858,8 +870,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + ShapeInferFn shape_infer_fn_ = {}; + int start_ver_ = 1; int end_ver_ = MAX_CUSTOM_OP_END_VER; + + void* compute_fn_ = {}; + void* compute_fn_return_status_ = {}; }; //////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -891,9 +908,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFn compute_fn, ShapeInferFn shape_infer_fn = {}, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), - compute_fn_(compute_fn), - shape_infer_fn_(shape_infer_fn) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_ = reinterpret_cast(compute_fn); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { @@ -905,7 +921,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_ = static_cast(this_)->compute_fn_; + auto me = static_cast(this_); + kernel->compute_fn_ = reinterpret_cast(me->compute_fn_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -931,9 +948,8 @@ struct OrtLiteCustomFunc : public
OrtLiteCustomOp { ComputeFnReturnStatus compute_fn_return_status, ShapeInferFn shape_infer_fn = {}, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), - compute_fn_return_status_(compute_fn_return_status), - shape_infer_fn_(shape_infer_fn) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_return_status_ = reinterpret_cast(compute_fn_return_status); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { @@ -945,7 +961,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_return_status_ = static_cast(this_)->compute_fn_return_status_; + auto me = static_cast(this_); + kernel->compute_fn_return_status_ = reinterpret_cast(me->compute_fn_return_status_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -965,10 +982,6 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { }; } } - - ComputeFn compute_fn_ = {}; - ComputeFnReturnStatus compute_fn_return_status_ = {}; - ShapeInferFn shape_infer_fn_ = {}; }; // struct OrtLiteCustomFunc /////////////////////////// OrtLiteCustomStruct /////////////////////////// @@ -1007,7 +1020,7 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { OrtLiteCustomStruct(const char* op_name, const char* execution_provider, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) { SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { diff --git a/js/.eslintrc.js b/js/.eslintrc.js index fd30cb96a5bd0..0bf47c5264f61 100644 --- a/js/.eslintrc.js +++ b/js/.eslintrc.js @@ -5,10 +5,18 @@ module.exports = { root: true, - ignorePatterns: ['**/*.js', 'ort-schema/', 'common/test/type-tests/', 'test/data/', 'node_modules/', 'dist/'], + ignorePatterns: [ + '**/*.js', + 'node_modules/', + 'ort-schema/', + 'common/test/type-tests/', + 'web/types.d.ts', + 'test/data/', + 'dist/', + ], env: { 'es6': true }, parser: '@typescript-eslint/parser', - parserOptions: { 'project': 'tsconfig.json', 'sourceType': 'module' }, + parserOptions: { 'project': true, 'sourceType': 'module' }, plugins: ['@typescript-eslint', 'prefer-arrow', 'header', 'import', 'unicorn', 'jsdoc'], rules: { 'unicorn/filename-case': 'error', @@ -144,15 +152,56 @@ module.exports = { 'no-unused-expressions': 'off', } }, { - files: ['web/lib/**/*.ts'], - excludedFiles: 'web/lib/wasm/proxy-worker/**/*', - parserOptions: { 'project': 'web/tsconfig.json' }, - rules: { - 'no-underscore-dangle': 'off', + files: ['web/lib/**/*.ts'], rules: { + 'no-underscore-dangle': ['error', { + 'allow': [ + '_free', + '_malloc', + '_JsepGetNodeName', + '_JsepOutput', + '_OrtAddFreeDimensionOverride', + '_OrtAddRunConfigEntry', + '_OrtAddSessionConfigEntry', + '_OrtAppendExecutionProvider', + '_OrtBindInput', + '_OrtBindOutput', + '_OrtClearBoundOutputs', + '_OrtCreateBinding', + '_OrtCreateRunOptions', + '_OrtCreateSession', + 
'_OrtCreateSessionOptions', + '_OrtCreateTensor', + '_OrtEndProfiling', + '_OrtFree', + '_OrtGetInputName', + '_OrtGetInputOutputCount', + '_OrtGetLastError', + '_OrtGetOutputName', + '_OrtGetTensorData', + '_OrtInit', + '_OrtReleaseBinding', + '_OrtReleaseRunOptions', + '_OrtReleaseSession', + '_OrtReleaseSessionOptions', + '_OrtReleaseTensor', + '_OrtRun', + '_OrtRunWithBinding', + '_OrtTrainingCopyParametersFromBuffer', + '_OrtTrainingCopyParametersToBuffer', + '_OrtTrainingCreateSession', + '_OrtTrainingEvalStep', + '_OrtTrainingGetModelInputOutputCount', + '_OrtTrainingGetModelInputOutputName', + '_OrtTrainingGetParametersSize', + '_OrtTrainingLazyResetGrad', + '_OrtTrainingLoadCheckpoint', + '_OrtTrainingOptimizerStep', + '_OrtTrainingReleaseCheckpoint', + '_OrtTrainingReleaseSession', + '_OrtTrainingRunTrainStep' + ] + }] } - }, { - files: ['web/lib/wasm/proxy-worker/**/*.ts'], - parserOptions: { 'project': 'web/lib/wasm/proxy-worker/tsconfig.json' }, }, { files: ['web/lib/onnxjs/**/*.ts'], rules: { // TODO: those rules are useful. should turn on them in future (webgl refactor) @@ -164,6 +213,7 @@ module.exports = { 'import/no-internal-modules': 'off', 'prefer-arrow/prefer-arrow-functions': 'off', 'no-param-reassign': 'off', + 'no-underscore-dangle': 'off', 'guard-for-in': 'off' } }, { diff --git a/js/node/package-lock.json b/js/node/package-lock.json index e8968bafc4a9f..c1cf8af4bb80e 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -22,7 +22,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" } }, "../common": { @@ -97,12 +97,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "node_modules/@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "node_modules/@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -528,9 +522,9 @@ "dev": true }, "node_modules/long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "node_modules/lru-cache": { @@ -663,15 +657,6 @@ "node": "^12.13.0 || ^14.15.0 || >=16.0.0" } }, - "node_modules/onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "dependencies": { - "protobufjs": "^6.11.2" - } - }, "node_modules/onnxruntime-common": { "resolved": "../common", "link": true @@ -690,9 +675,9 @@ } }, "node_modules/protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": 
"sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "hasInstallScript": true, "dependencies": { @@ -706,13 +691,11 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" }, - "bin": { - "pbjs": "bin/pbjs", - "pbts": "bin/pbts" + "engines": { + "node": ">=12.0.0" } }, "node_modules/proxy-from-env": { @@ -789,9 +772,9 @@ ] }, "node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "dependencies": { "lru-cache": "^6.0.0" @@ -1070,12 +1053,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -1413,9 +1390,9 @@ "dev": true }, "long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "lru-cache": { @@ -1523,15 +1500,6 @@ "set-blocking": "^2.0.0" } }, - "onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "requires": { - "protobufjs": "^6.11.2" - } - }, "onnxruntime-common": { "version": "file:../common", "requires": { @@ -1549,9 +1517,9 @@ } }, "protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "requires": { "@protobufjs/aspromise": "^1.1.2", @@ -1564,9 +1532,8 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" } }, "proxy-from-env": { @@ -1619,9 +1586,9 @@ "dev": true }, "semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": 
"sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "requires": { "lru-cache": "^6.0.0" diff --git a/js/node/package.json b/js/node/package.json index 0f8f0e9d2260c..8e591d8f46b9d 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -19,6 +19,7 @@ }, "scripts": { "buildr": "tsc && node ./script/build --config=RelWithDebInfo", + "preprepare": "node -e \"require('node:fs').copyFileSync('./node_modules/long/index.d.ts', './node_modules/long/umd/index.d.ts')\"", "prepare": "tsc --build script test .", "rebuild": "tsc && node ./script/build --rebuild", "rebuildd": "tsc && node ./script/build --rebuild --config=Debug", @@ -39,7 +40,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" }, "main": "dist/index.js", "os": [ diff --git a/js/node/test/ort-schema/protobuf/.gitignore b/js/node/test/ort-schema/protobuf/.gitignore new file mode 100644 index 0000000000000..092bb6c1c9fb4 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/.gitignore @@ -0,0 +1,2 @@ +!onnx.js +!onnx.d.ts diff --git a/js/node/test/ort-schema/protobuf/README.md b/js/node/test/ort-schema/protobuf/README.md new file mode 100644 index 0000000000000..f5f52c602f1ad --- /dev/null +++ b/js/node/test/ort-schema/protobuf/README.md @@ -0,0 +1,21 @@ +# ONNX protobuf + +This directory contains generated protobuf definition for onnx: + +- onnx.js +- onnx.d.ts + +These files are generated from [a fork of onnx-proto](https://github.com/fs-eire/onnx-proto/tree/update-v9). + +The ONNX protobuf uses protobufjs@7.2.4, which depends on long@5.2.3, the version contains 2 bugs: + +- type export does not work with commonjs. described in https://github.com/dcodeIO/long.js/pull/124. added a "postinstall" script to fix. +- in the generated typescript declaration file 'onnx.d.ts', the following line: + ```ts + import Long = require("long"); + ``` + need to be replaced to fix type import error: + ```ts + import Long from "long"; + ``` + this replacement is done and code format is also applied to file 'onnx.d.ts'. diff --git a/js/node/test/ort-schema/protobuf/onnx.d.ts b/js/node/test/ort-schema/protobuf/onnx.d.ts new file mode 100644 index 0000000000000..c60264dca2a8d --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.d.ts @@ -0,0 +1,2627 @@ +import Long from 'long'; +import * as $protobuf from 'protobufjs'; + +/** Namespace onnx. */ +export namespace onnx { + + /** Version enum. */ + enum Version { + _START_VERSION = 0, + IR_VERSION_2017_10_10 = 1, + IR_VERSION_2017_10_30 = 2, + IR_VERSION_2017_11_3 = 3, + IR_VERSION_2019_1_22 = 4, + IR_VERSION_2019_3_18 = 5, + IR_VERSION_2019_9_19 = 6, + IR_VERSION_2020_5_8 = 7, + IR_VERSION_2021_7_30 = 8, + IR_VERSION = 9 + } + + /** Properties of an AttributeProto. 
*/ + interface IAttributeProto { + /** AttributeProto name */ + name?: (string|null); + + /** AttributeProto refAttrName */ + refAttrName?: (string|null); + + /** AttributeProto docString */ + docString?: (string|null); + + /** AttributeProto type */ + type?: (onnx.AttributeProto.AttributeType|null); + + /** AttributeProto f */ + f?: (number|null); + + /** AttributeProto i */ + i?: (number|Long|null); + + /** AttributeProto s */ + s?: (Uint8Array|null); + + /** AttributeProto t */ + t?: (onnx.ITensorProto|null); + + /** AttributeProto g */ + g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor */ + sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp */ + tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats */ + floats?: (number[]|null); + + /** AttributeProto ints */ + ints?: ((number | Long)[]|null); + + /** AttributeProto strings */ + strings?: (Uint8Array[]|null); + + /** AttributeProto tensors */ + tensors?: (onnx.ITensorProto[]|null); + + /** AttributeProto graphs */ + graphs?: (onnx.IGraphProto[]|null); + + /** AttributeProto sparseTensors */ + sparseTensors?: (onnx.ISparseTensorProto[]|null); + + /** AttributeProto typeProtos */ + typeProtos?: (onnx.ITypeProto[]|null); + } + + /** Represents an AttributeProto. */ + class AttributeProto implements IAttributeProto { + /** + * Constructs a new AttributeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IAttributeProto); + + /** AttributeProto name. */ + public name: string; + + /** AttributeProto refAttrName. */ + public refAttrName: string; + + /** AttributeProto docString. */ + public docString: string; + + /** AttributeProto type. */ + public type: onnx.AttributeProto.AttributeType; + + /** AttributeProto f. */ + public f: number; + + /** AttributeProto i. */ + public i: (number|Long); + + /** AttributeProto s. */ + public s: Uint8Array; + + /** AttributeProto t. */ + public t?: (onnx.ITensorProto|null); + + /** AttributeProto g. */ + public g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor. */ + public sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp. */ + public tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats. */ + public floats: number[]; + + /** AttributeProto ints. */ + public ints: (number|Long)[]; + + /** AttributeProto strings. */ + public strings: Uint8Array[]; + + /** AttributeProto tensors. */ + public tensors: onnx.ITensorProto[]; + + /** AttributeProto graphs. */ + public graphs: onnx.IGraphProto[]; + + /** AttributeProto sparseTensors. */ + public sparseTensors: onnx.ISparseTensorProto[]; + + /** AttributeProto typeProtos. */ + public typeProtos: onnx.ITypeProto[]; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns AttributeProto instance + */ + public static create(properties?: onnx.IAttributeProto): onnx.AttributeProto; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} + * messages. + * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link + * onnx.AttributeProto.verify|verify} messages. 
+ * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.AttributeProto; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.AttributeProto; + + /** + * Verifies an AttributeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns AttributeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.AttributeProto; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @param message AttributeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.AttributeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this AttributeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for AttributeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace AttributeProto { + + /** AttributeType enum. */ + enum AttributeType { + UNDEFINED = 0, + FLOAT = 1, + INT = 2, + STRING = 3, + TENSOR = 4, + GRAPH = 5, + SPARSE_TENSOR = 11, + TYPE_PROTO = 13, + FLOATS = 6, + INTS = 7, + STRINGS = 8, + TENSORS = 9, + GRAPHS = 10, + SPARSE_TENSORS = 12, + TYPE_PROTOS = 14 + } + } + + /** Properties of a ValueInfoProto. */ + interface IValueInfoProto { + /** ValueInfoProto name */ + name?: (string|null); + + /** ValueInfoProto type */ + type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString */ + docString?: (string|null); + } + + /** Represents a ValueInfoProto. */ + class ValueInfoProto implements IValueInfoProto { + /** + * Constructs a new ValueInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IValueInfoProto); + + /** ValueInfoProto name. */ + public name: string; + + /** ValueInfoProto type. */ + public type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString. */ + public docString: string; + + /** + * Creates a new ValueInfoProto instance using the specified properties. 
+ * @param [properties] Properties to set + * @returns ValueInfoProto instance + */ + public static create(properties?: onnx.IValueInfoProto): onnx.ValueInfoProto; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} + * messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link + * onnx.ValueInfoProto.verify|verify} messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ValueInfoProto; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ValueInfoProto; + + /** + * Verifies a ValueInfoProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ValueInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ValueInfoProto; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. + * @param message ValueInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ValueInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ValueInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ValueInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a NodeProto. */ + interface INodeProto { + /** NodeProto input */ + input?: (string[]|null); + + /** NodeProto output */ + output?: (string[]|null); + + /** NodeProto name */ + name?: (string|null); + + /** NodeProto opType */ + opType?: (string|null); + + /** NodeProto domain */ + domain?: (string|null); + + /** NodeProto attribute */ + attribute?: (onnx.IAttributeProto[]|null); + + /** NodeProto docString */ + docString?: (string|null); + } + + /** Represents a NodeProto. 
*/ + class NodeProto implements INodeProto { + /** + * Constructs a new NodeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.INodeProto); + + /** NodeProto input. */ + public input: string[]; + + /** NodeProto output. */ + public output: string[]; + + /** NodeProto name. */ + public name: string; + + /** NodeProto opType. */ + public opType: string; + + /** NodeProto domain. */ + public domain: string; + + /** NodeProto attribute. */ + public attribute: onnx.IAttributeProto[]; + + /** NodeProto docString. */ + public docString: string; + + /** + * Creates a new NodeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns NodeProto instance + */ + public static create(properties?: onnx.INodeProto): onnx.NodeProto; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link + * onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.NodeProto; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.NodeProto; + + /** + * Verifies a NodeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns NodeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.NodeProto; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @param message NodeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.NodeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this NodeProto to JSON. 
+ * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for NodeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TrainingInfoProto. */ + interface ITrainingInfoProto { + /** TrainingInfoProto initialization */ + initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm */ + algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding */ + initializationBinding?: (onnx.IStringStringEntryProto[]|null); + + /** TrainingInfoProto updateBinding */ + updateBinding?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TrainingInfoProto. */ + class TrainingInfoProto implements ITrainingInfoProto { + /** + * Constructs a new TrainingInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITrainingInfoProto); + + /** TrainingInfoProto initialization. */ + public initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm. */ + public algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding. */ + public initializationBinding: onnx.IStringStringEntryProto[]; + + /** TrainingInfoProto updateBinding. */ + public updateBinding: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TrainingInfoProto instance + */ + public static create(properties?: onnx.ITrainingInfoProto): onnx.TrainingInfoProto; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} + * messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link + * onnx.TrainingInfoProto.verify|verify} messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TrainingInfoProto; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TrainingInfoProto; + + /** + * Verifies a TrainingInfoProto message. 
+ * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TrainingInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TrainingInfoProto; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @param message TrainingInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TrainingInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TrainingInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TrainingInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a ModelProto. */ + interface IModelProto { + /** ModelProto irVersion */ + irVersion?: (number|Long|null); + + /** ModelProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** ModelProto producerName */ + producerName?: (string|null); + + /** ModelProto producerVersion */ + producerVersion?: (string|null); + + /** ModelProto domain */ + domain?: (string|null); + + /** ModelProto modelVersion */ + modelVersion?: (number|Long|null); + + /** ModelProto docString */ + docString?: (string|null); + + /** ModelProto graph */ + graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps */ + metadataProps?: (onnx.IStringStringEntryProto[]|null); + + /** ModelProto trainingInfo */ + trainingInfo?: (onnx.ITrainingInfoProto[]|null); + + /** ModelProto functions */ + functions?: (onnx.IFunctionProto[]|null); + } + + /** Represents a ModelProto. */ + class ModelProto implements IModelProto { + /** + * Constructs a new ModelProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IModelProto); + + /** ModelProto irVersion. */ + public irVersion: (number|Long); + + /** ModelProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** ModelProto producerName. */ + public producerName: string; + + /** ModelProto producerVersion. */ + public producerVersion: string; + + /** ModelProto domain. */ + public domain: string; + + /** ModelProto modelVersion. */ + public modelVersion: (number|Long); + + /** ModelProto docString. */ + public docString: string; + + /** ModelProto graph. */ + public graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps. */ + public metadataProps: onnx.IStringStringEntryProto[]; + + /** ModelProto trainingInfo. */ + public trainingInfo: onnx.ITrainingInfoProto[]; + + /** ModelProto functions. */ + public functions: onnx.IFunctionProto[]; + + /** + * Creates a new ModelProto instance using the specified properties. + * @param [properties] Properties to set + * @returns ModelProto instance + */ + public static create(properties?: onnx.IModelProto): onnx.ModelProto; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. 
+ * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link + * onnx.ModelProto.verify|verify} messages. + * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ModelProto; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ModelProto; + + /** + * Verifies a ModelProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ModelProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ModelProto; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @param message ModelProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ModelProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ModelProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ModelProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a StringStringEntryProto. */ + interface IStringStringEntryProto { + /** StringStringEntryProto key */ + key?: (string|null); + + /** StringStringEntryProto value */ + value?: (string|null); + } + + /** Represents a StringStringEntryProto. */ + class StringStringEntryProto implements IStringStringEntryProto { + /** + * Constructs a new StringStringEntryProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IStringStringEntryProto); + + /** StringStringEntryProto key. */ + public key: string; + + /** StringStringEntryProto value. */ + public value: string; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. 
+ * @param [properties] Properties to set + * @returns StringStringEntryProto instance + */ + public static create(properties?: onnx.IStringStringEntryProto): onnx.StringStringEntryProto; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.StringStringEntryProto; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.StringStringEntryProto; + + /** + * Verifies a StringStringEntryProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns StringStringEntryProto + */ + public static fromObject(object: {[k: string]: any}): onnx.StringStringEntryProto; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @param message StringStringEntryProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.StringStringEntryProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this StringStringEntryProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for StringStringEntryProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorAnnotation. */ + interface ITensorAnnotation { + /** TensorAnnotation tensorName */ + tensorName?: (string|null); + + /** TensorAnnotation quantParameterTensorNames */ + quantParameterTensorNames?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TensorAnnotation. 
*/ + class TensorAnnotation implements ITensorAnnotation { + /** + * Constructs a new TensorAnnotation. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorAnnotation); + + /** TensorAnnotation tensorName. */ + public tensorName: string; + + /** TensorAnnotation quantParameterTensorNames. */ + public quantParameterTensorNames: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorAnnotation instance + */ + public static create(properties?: onnx.ITensorAnnotation): onnx.TensorAnnotation; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} + * messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link + * onnx.TensorAnnotation.verify|verify} messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorAnnotation; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorAnnotation; + + /** + * Verifies a TensorAnnotation message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorAnnotation + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorAnnotation; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @param message TensorAnnotation + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorAnnotation, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorAnnotation to JSON. 
+ * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorAnnotation + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a GraphProto. */ + interface IGraphProto { + /** GraphProto node */ + node?: (onnx.INodeProto[]|null); + + /** GraphProto name */ + name?: (string|null); + + /** GraphProto initializer */ + initializer?: (onnx.ITensorProto[]|null); + + /** GraphProto sparseInitializer */ + sparseInitializer?: (onnx.ISparseTensorProto[]|null); + + /** GraphProto docString */ + docString?: (string|null); + + /** GraphProto input */ + input?: (onnx.IValueInfoProto[]|null); + + /** GraphProto output */ + output?: (onnx.IValueInfoProto[]|null); + + /** GraphProto valueInfo */ + valueInfo?: (onnx.IValueInfoProto[]|null); + + /** GraphProto quantizationAnnotation */ + quantizationAnnotation?: (onnx.ITensorAnnotation[]|null); + } + + /** Represents a GraphProto. */ + class GraphProto implements IGraphProto { + /** + * Constructs a new GraphProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IGraphProto); + + /** GraphProto node. */ + public node: onnx.INodeProto[]; + + /** GraphProto name. */ + public name: string; + + /** GraphProto initializer. */ + public initializer: onnx.ITensorProto[]; + + /** GraphProto sparseInitializer. */ + public sparseInitializer: onnx.ISparseTensorProto[]; + + /** GraphProto docString. */ + public docString: string; + + /** GraphProto input. */ + public input: onnx.IValueInfoProto[]; + + /** GraphProto output. */ + public output: onnx.IValueInfoProto[]; + + /** GraphProto valueInfo. */ + public valueInfo: onnx.IValueInfoProto[]; + + /** GraphProto quantizationAnnotation. */ + public quantizationAnnotation: onnx.ITensorAnnotation[]; + + /** + * Creates a new GraphProto instance using the specified properties. + * @param [properties] Properties to set + * @returns GraphProto instance + */ + public static create(properties?: onnx.IGraphProto): onnx.GraphProto; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link + * onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.GraphProto; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.GraphProto; + + /** + * Verifies a GraphProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns GraphProto + */ + public static fromObject(object: {[k: string]: any}): onnx.GraphProto; + + /** + * Creates a plain object from a GraphProto message. Also converts values to other types if specified. + * @param message GraphProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.GraphProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this GraphProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for GraphProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorProto. */ + interface ITensorProto { + /** TensorProto dims */ + dims?: ((number | Long)[]|null); + + /** TensorProto dataType */ + dataType?: (number|null); + + /** TensorProto segment */ + segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData */ + floatData?: (number[]|null); + + /** TensorProto int32Data */ + int32Data?: (number[]|null); + + /** TensorProto stringData */ + stringData?: (Uint8Array[]|null); + + /** TensorProto int64Data */ + int64Data?: ((number | Long)[]|null); + + /** TensorProto name */ + name?: (string|null); + + /** TensorProto docString */ + docString?: (string|null); + + /** TensorProto rawData */ + rawData?: (Uint8Array|null); + + /** TensorProto externalData */ + externalData?: (onnx.IStringStringEntryProto[]|null); + + /** TensorProto dataLocation */ + dataLocation?: (onnx.TensorProto.DataLocation|null); + + /** TensorProto doubleData */ + doubleData?: (number[]|null); + + /** TensorProto uint64Data */ + uint64Data?: ((number | Long)[]|null); + } + + /** Represents a TensorProto. */ + class TensorProto implements ITensorProto { + /** + * Constructs a new TensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorProto); + + /** TensorProto dims. */ + public dims: (number|Long)[]; + + /** TensorProto dataType. */ + public dataType: number; + + /** TensorProto segment. */ + public segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData. */ + public floatData: number[]; + + /** TensorProto int32Data. */ + public int32Data: number[]; + + /** TensorProto stringData. */ + public stringData: Uint8Array[]; + + /** TensorProto int64Data. */ + public int64Data: (number|Long)[]; + + /** TensorProto name. */ + public name: string; + + /** TensorProto docString. */ + public docString: string; + + /** TensorProto rawData. */ + public rawData: Uint8Array; + + /** TensorProto externalData. */ + public externalData: onnx.IStringStringEntryProto[]; + + /** TensorProto dataLocation. 
*/ + public dataLocation: onnx.TensorProto.DataLocation; + + /** TensorProto doubleData. */ + public doubleData: number[]; + + /** TensorProto uint64Data. */ + public uint64Data: (number|Long)[]; + + /** + * Creates a new TensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorProto instance + */ + public static create(properties?: onnx.ITensorProto): onnx.TensorProto; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link + * onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto; + + /** + * Verifies a TensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @param message TensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorProto { + + /** DataType enum. 
*/ + enum DataType { + UNDEFINED = 0, + FLOAT = 1, + UINT8 = 2, + INT8 = 3, + UINT16 = 4, + INT16 = 5, + INT32 = 6, + INT64 = 7, + STRING = 8, + BOOL = 9, + FLOAT16 = 10, + DOUBLE = 11, + UINT32 = 12, + UINT64 = 13, + COMPLEX64 = 14, + COMPLEX128 = 15, + BFLOAT16 = 16, + FLOAT8E4M3FN = 17, + FLOAT8E4M3FNUZ = 18, + FLOAT8E5M2 = 19, + FLOAT8E5M2FNUZ = 20 + } + + /** Properties of a Segment. */ + interface ISegment { + /** Segment begin */ + begin?: (number|Long|null); + + /** Segment end */ + end?: (number|Long|null); + } + + /** Represents a Segment. */ + class Segment implements ISegment { + /** + * Constructs a new Segment. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorProto.ISegment); + + /** Segment begin. */ + public begin: (number|Long); + + /** Segment end. */ + public end: (number|Long); + + /** + * Creates a new Segment instance using the specified properties. + * @param [properties] Properties to set + * @returns Segment instance + */ + public static create(properties?: onnx.TensorProto.ISegment): onnx.TensorProto.Segment; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} + * messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link + * onnx.TensorProto.Segment.verify|verify} messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto.Segment; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto.Segment; + + /** + * Verifies a Segment message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Segment + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto.Segment; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @param message Segment + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto.Segment, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Segment to JSON. 
+ * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Segment + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** DataLocation enum. */ + enum DataLocation { DEFAULT = 0, EXTERNAL = 1 } + } + + /** Properties of a SparseTensorProto. */ + interface ISparseTensorProto { + /** SparseTensorProto values */ + values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices */ + indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims */ + dims?: ((number | Long)[]|null); + } + + /** Represents a SparseTensorProto. */ + class SparseTensorProto implements ISparseTensorProto { + /** + * Constructs a new SparseTensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ISparseTensorProto); + + /** SparseTensorProto values. */ + public values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices. */ + public indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims. */ + public dims: (number|Long)[]; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensorProto instance + */ + public static create(properties?: onnx.ISparseTensorProto): onnx.SparseTensorProto; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} + * messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link + * onnx.SparseTensorProto.verify|verify} messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.SparseTensorProto; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.SparseTensorProto; + + /** + * Verifies a SparseTensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. 
+ * @param object Plain object + * @returns SparseTensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.SparseTensorProto; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @param message SparseTensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.SparseTensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this SparseTensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorShapeProto. */ + interface ITensorShapeProto { + /** TensorShapeProto dim */ + dim?: (onnx.TensorShapeProto.IDimension[]|null); + } + + /** Represents a TensorShapeProto. */ + class TensorShapeProto implements ITensorShapeProto { + /** + * Constructs a new TensorShapeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorShapeProto); + + /** TensorShapeProto dim. */ + public dim: onnx.TensorShapeProto.IDimension[]; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorShapeProto instance + */ + public static create(properties?: onnx.ITensorShapeProto): onnx.TensorShapeProto; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} + * messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.verify|verify} messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto; + + /** + * Verifies a TensorShapeProto message. 
+ * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorShapeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @param message TensorShapeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorShapeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorShapeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorShapeProto { + + /** Properties of a Dimension. */ + interface IDimension { + /** Dimension dimValue */ + dimValue?: (number|Long|null); + + /** Dimension dimParam */ + dimParam?: (string|null); + + /** Dimension denotation */ + denotation?: (string|null); + } + + /** Represents a Dimension. */ + class Dimension implements IDimension { + /** + * Constructs a new Dimension. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorShapeProto.IDimension); + + /** Dimension dimValue. */ + public dimValue?: (number|Long|null); + + /** Dimension dimParam. */ + public dimParam?: (string|null); + + /** Dimension denotation. */ + public denotation: string; + + /** Dimension value. */ + public value?: ('dimValue'|'dimParam'); + + /** + * Creates a new Dimension instance using the specified properties. + * @param [properties] Properties to set + * @returns Dimension instance + */ + public static create(properties?: onnx.TensorShapeProto.IDimension): onnx.TensorShapeProto.Dimension; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): + $protobuf.Writer; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto.Dimension; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto.Dimension; + + /** + * Verifies a Dimension message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Dimension + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto.Dimension; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @param message Dimension + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto.Dimension, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Dimension to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Dimension + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of a TypeProto. */ + interface ITypeProto { + /** TypeProto tensorType */ + tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType */ + sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType */ + mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType */ + optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType */ + sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation */ + denotation?: (string|null); + } + + /** Represents a TypeProto. */ + class TypeProto implements ITypeProto { + /** + * Constructs a new TypeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITypeProto); + + /** TypeProto tensorType. */ + public tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType. */ + public sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType. */ + public mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType. */ + public optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType. */ + public sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation. */ + public denotation: string; + + /** TypeProto value. */ + public value?: ('tensorType'|'sequenceType'|'mapType'|'optionalType'|'sparseTensorType'); + + /** + * Creates a new TypeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TypeProto instance + */ + public static create(properties?: onnx.ITypeProto): onnx.TypeProto; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TypeProto message, length delimited. 
Does not implicitly {@link + * onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto; + + /** + * Verifies a TypeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TypeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. + * @param message TypeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TypeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TypeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TypeProto { + + /** Properties of a Tensor. */ + interface ITensor { + /** Tensor elemType */ + elemType?: (number|null); + + /** Tensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a Tensor. */ + class Tensor implements ITensor { + /** + * Constructs a new Tensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ITensor); + + /** Tensor elemType. */ + public elemType: number; + + /** Tensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new Tensor instance using the specified properties. + * @param [properties] Properties to set + * @returns Tensor instance + */ + public static create(properties?: onnx.TypeProto.ITensor): onnx.TypeProto.Tensor; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Tensor message, length delimited. 
Does not implicitly {@link + * onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Tensor; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Tensor; + + /** + * Verifies a Tensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Tensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Tensor; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @param message Tensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Tensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Tensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Tensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Sequence. */ + interface ISequence { + /** Sequence elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents a Sequence. */ + class Sequence implements ISequence { + /** + * Constructs a new Sequence. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISequence); + + /** Sequence elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Sequence instance using the specified properties. + * @param [properties] Properties to set + * @returns Sequence instance + */ + public static create(properties?: onnx.TypeProto.ISequence): onnx.TypeProto.Sequence; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} + * messages. + * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Sequence.verify|verify} messages. 
+ * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Sequence; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Sequence; + + /** + * Verifies a Sequence message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Sequence + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Sequence; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @param message Sequence + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Sequence, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Sequence to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Sequence + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Map. */ + interface IMap { + /** Map keyType */ + keyType?: (number|null); + + /** Map valueType */ + valueType?: (onnx.ITypeProto|null); + } + + /** Represents a Map. */ + class Map implements IMap { + /** + * Constructs a new Map. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IMap); + + /** Map keyType. */ + public keyType: number; + + /** Map valueType. */ + public valueType?: (onnx.ITypeProto|null); + + /** + * Creates a new Map instance using the specified properties. + * @param [properties] Properties to set + * @returns Map instance + */ + public static create(properties?: onnx.TypeProto.IMap): onnx.TypeProto.Map; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Map.verify|verify} messages. 
+ * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Map message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Map; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Map; + + /** + * Verifies a Map message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Map + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Map; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @param message Map + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Map, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this Map to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Map + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of an Optional. */ + interface IOptional { + /** Optional elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents an Optional. */ + class Optional implements IOptional { + /** + * Constructs a new Optional. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IOptional); + + /** Optional elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Optional instance using the specified properties. + * @param [properties] Properties to set + * @returns Optional instance + */ + public static create(properties?: onnx.TypeProto.IOptional): onnx.TypeProto.Optional; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} + * messages. + * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Optional.verify|verify} messages. 
+ * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Optional; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Optional; + + /** + * Verifies an Optional message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Optional + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Optional; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. + * @param message Optional + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Optional, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Optional to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Optional + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a SparseTensor. */ + interface ISparseTensor { + /** SparseTensor elemType */ + elemType?: (number|null); + + /** SparseTensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a SparseTensor. */ + class SparseTensor implements ISparseTensor { + /** + * Constructs a new SparseTensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISparseTensor); + + /** SparseTensor elemType. */ + public elemType: number; + + /** SparseTensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new SparseTensor instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensor instance + */ + public static create(properties?: onnx.TypeProto.ISparseTensor): onnx.TypeProto.SparseTensor; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensor message, length delimited. 
Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.SparseTensor; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.SparseTensor; + + /** + * Verifies a SparseTensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns SparseTensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.SparseTensor; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @param message SparseTensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.SparseTensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this SparseTensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of an OperatorSetIdProto. */ + interface IOperatorSetIdProto { + /** OperatorSetIdProto domain */ + domain?: (string|null); + + /** OperatorSetIdProto version */ + version?: (number|Long|null); + } + + /** Represents an OperatorSetIdProto. */ + class OperatorSetIdProto implements IOperatorSetIdProto { + /** + * Constructs a new OperatorSetIdProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IOperatorSetIdProto); + + /** OperatorSetIdProto domain. */ + public domain: string; + + /** OperatorSetIdProto version. */ + public version: (number|Long); + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @param [properties] Properties to set + * @returns OperatorSetIdProto instance + */ + public static create(properties?: onnx.IOperatorSetIdProto): onnx.OperatorSetIdProto; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. 
+ * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. + * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.OperatorSetIdProto; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.OperatorSetIdProto; + + /** + * Verifies an OperatorSetIdProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns OperatorSetIdProto + */ + public static fromObject(object: {[k: string]: any}): onnx.OperatorSetIdProto; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @param message OperatorSetIdProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.OperatorSetIdProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this OperatorSetIdProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for OperatorSetIdProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** OperatorStatus enum. */ + enum OperatorStatus { EXPERIMENTAL = 0, STABLE = 1 } + + /** Properties of a FunctionProto. 
*/ + interface IFunctionProto { + /** FunctionProto name */ + name?: (string|null); + + /** FunctionProto input */ + input?: (string[]|null); + + /** FunctionProto output */ + output?: (string[]|null); + + /** FunctionProto attribute */ + attribute?: (string[]|null); + + /** FunctionProto attributeProto */ + attributeProto?: (onnx.IAttributeProto[]|null); + + /** FunctionProto node */ + node?: (onnx.INodeProto[]|null); + + /** FunctionProto docString */ + docString?: (string|null); + + /** FunctionProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** FunctionProto domain */ + domain?: (string|null); + } + + /** Represents a FunctionProto. */ + class FunctionProto implements IFunctionProto { + /** + * Constructs a new FunctionProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IFunctionProto); + + /** FunctionProto name. */ + public name: string; + + /** FunctionProto input. */ + public input: string[]; + + /** FunctionProto output. */ + public output: string[]; + + /** FunctionProto attribute. */ + public attribute: string[]; + + /** FunctionProto attributeProto. */ + public attributeProto: onnx.IAttributeProto[]; + + /** FunctionProto node. */ + public node: onnx.INodeProto[]; + + /** FunctionProto docString. */ + public docString: string; + + /** FunctionProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** FunctionProto domain. */ + public domain: string; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @param [properties] Properties to set + * @returns FunctionProto instance + */ + public static create(properties?: onnx.IFunctionProto): onnx.FunctionProto; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} + * messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link + * onnx.FunctionProto.verify|verify} messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.FunctionProto; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.FunctionProto; + + /** + * Verifies a FunctionProto message. 
+ * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns FunctionProto + */ + public static fromObject(object: {[k: string]: any}): onnx.FunctionProto; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. + * @param message FunctionProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.FunctionProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this FunctionProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for FunctionProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } +} diff --git a/js/node/test/ort-schema/protobuf/onnx.js b/js/node/test/ort-schema/protobuf/onnx.js new file mode 100644 index 0000000000000..681855132d4e8 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.js @@ -0,0 +1,7658 @@ +/*eslint-disable block-scoped-var, id-length, no-control-regex, no-magic-numbers, no-prototype-builtins, no-redeclare, no-shadow, no-var, sort-vars*/ +"use strict"; + +var $protobuf = require("protobufjs/minimal"); + +// Common aliases +var $Reader = $protobuf.Reader, $Writer = $protobuf.Writer, $util = $protobuf.util; + +// Exported root namespace +var $root = $protobuf.roots["default"] || ($protobuf.roots["default"] = {}); + +$root.onnx = (function() { + + /** + * Namespace onnx. + * @exports onnx + * @namespace + */ + var onnx = {}; + + /** + * Version enum. + * @name onnx.Version + * @enum {number} + * @property {number} _START_VERSION=0 _START_VERSION value + * @property {number} IR_VERSION_2017_10_10=1 IR_VERSION_2017_10_10 value + * @property {number} IR_VERSION_2017_10_30=2 IR_VERSION_2017_10_30 value + * @property {number} IR_VERSION_2017_11_3=3 IR_VERSION_2017_11_3 value + * @property {number} IR_VERSION_2019_1_22=4 IR_VERSION_2019_1_22 value + * @property {number} IR_VERSION_2019_3_18=5 IR_VERSION_2019_3_18 value + * @property {number} IR_VERSION_2019_9_19=6 IR_VERSION_2019_9_19 value + * @property {number} IR_VERSION_2020_5_8=7 IR_VERSION_2020_5_8 value + * @property {number} IR_VERSION_2021_7_30=8 IR_VERSION_2021_7_30 value + * @property {number} IR_VERSION=9 IR_VERSION value + */ + onnx.Version = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "_START_VERSION"] = 0; + values[valuesById[1] = "IR_VERSION_2017_10_10"] = 1; + values[valuesById[2] = "IR_VERSION_2017_10_30"] = 2; + values[valuesById[3] = "IR_VERSION_2017_11_3"] = 3; + values[valuesById[4] = "IR_VERSION_2019_1_22"] = 4; + values[valuesById[5] = "IR_VERSION_2019_3_18"] = 5; + values[valuesById[6] = "IR_VERSION_2019_9_19"] = 6; + values[valuesById[7] = "IR_VERSION_2020_5_8"] = 7; + values[valuesById[8] = "IR_VERSION_2021_7_30"] = 8; + values[valuesById[9] = "IR_VERSION"] = 9; + return values; + })(); + + onnx.AttributeProto = (function() { + + /** + * Properties of an AttributeProto. 
+     * @memberof onnx
+     * @interface IAttributeProto
+     * @property {string|null} [name] AttributeProto name
+     * @property {string|null} [refAttrName] AttributeProto refAttrName
+     * @property {string|null} [docString] AttributeProto docString
+     * @property {onnx.AttributeProto.AttributeType|null} [type] AttributeProto type
+     * @property {number|null} [f] AttributeProto f
+     * @property {number|Long|null} [i] AttributeProto i
+     * @property {Uint8Array|null} [s] AttributeProto s
+     * @property {onnx.ITensorProto|null} [t] AttributeProto t
+     * @property {onnx.IGraphProto|null} [g] AttributeProto g
+     * @property {onnx.ISparseTensorProto|null} [sparseTensor] AttributeProto sparseTensor
+     * @property {onnx.ITypeProto|null} [tp] AttributeProto tp
+     * @property {Array.<number>|null} [floats] AttributeProto floats
+     * @property {Array.<number|Long>|null} [ints] AttributeProto ints
+     * @property {Array.<Uint8Array>|null} [strings] AttributeProto strings
+     * @property {Array.<onnx.ITensorProto>|null} [tensors] AttributeProto tensors
+     * @property {Array.<onnx.IGraphProto>|null} [graphs] AttributeProto graphs
+     * @property {Array.<onnx.ISparseTensorProto>|null} [sparseTensors] AttributeProto sparseTensors
+     * @property {Array.<onnx.ITypeProto>|null} [typeProtos] AttributeProto typeProtos
+     */
+
+    /**
+     * Constructs a new AttributeProto.
+     * @memberof onnx
+     * @classdesc Represents an AttributeProto.
+     * @implements IAttributeProto
+     * @constructor
+     * @param {onnx.IAttributeProto=} [properties] Properties to set
+     */
+    function AttributeProto(properties) {
+      this.floats = [];
+      this.ints = [];
+      this.strings = [];
+      this.tensors = [];
+      this.graphs = [];
+      this.sparseTensors = [];
+      this.typeProtos = [];
+      if (properties)
+        for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i)
+          if (properties[keys[i]] != null)
+            this[keys[i]] = properties[keys[i]];
+    }
+
+    /**
+     * AttributeProto name.
+     * @member {string} name
+     * @memberof onnx.AttributeProto
+     * @instance
+     */
+    AttributeProto.prototype.name = "";
+
+    /**
+     * AttributeProto refAttrName.
+     * @member {string} refAttrName
+     * @memberof onnx.AttributeProto
+     * @instance
+     */
+    AttributeProto.prototype.refAttrName = "";
+
+    /**
+     * AttributeProto docString.
+     * @member {string} docString
+     * @memberof onnx.AttributeProto
+     * @instance
+     */
+    AttributeProto.prototype.docString = "";
+
+    /**
+     * AttributeProto type.
+     * @member {onnx.AttributeProto.AttributeType} type
+     * @memberof onnx.AttributeProto
+     * @instance
+     */
+    AttributeProto.prototype.type = 0;
+
+    /**
+     * AttributeProto f.
+     * @member {number} f
+     * @memberof onnx.AttributeProto
+     * @instance
+     */
+    AttributeProto.prototype.f = 0;
+
+    /**
+     * AttributeProto i.
+     * @member {number|Long} i
+     * @memberof onnx.AttributeProto
+     * @instance
+     */
+    AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0,0,false) : 0;
+
+    /**
+     * AttributeProto s.
+     * @member {Uint8Array} s
+     * @memberof onnx.AttributeProto
+     * @instance
+     */
+    AttributeProto.prototype.s = $util.newBuffer([]);
+
+    /**
+     * AttributeProto t.
+     * @member {onnx.ITensorProto|null|undefined} t
+     * @memberof onnx.AttributeProto
+     * @instance
+     */
+    AttributeProto.prototype.t = null;
+
+    /**
+     * AttributeProto g.
+     * @member {onnx.IGraphProto|null|undefined} g
+     * @memberof onnx.AttributeProto
+     * @instance
+     */
+    AttributeProto.prototype.g = null;
+
+    /**
+     * AttributeProto sparseTensor.
+     * @member {onnx.ISparseTensorProto|null|undefined} sparseTensor
+     * @memberof onnx.AttributeProto
+     * @instance
+     */
+    AttributeProto.prototype.sparseTensor = null;
+
+    /**
+     * AttributeProto tp.
+ * @member {onnx.ITypeProto|null|undefined} tp + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tp = null; + + /** + * AttributeProto floats. + * @member {Array.} floats + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.floats = $util.emptyArray; + + /** + * AttributeProto ints. + * @member {Array.} ints + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.ints = $util.emptyArray; + + /** + * AttributeProto strings. + * @member {Array.} strings + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.strings = $util.emptyArray; + + /** + * AttributeProto tensors. + * @member {Array.} tensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tensors = $util.emptyArray; + + /** + * AttributeProto graphs. + * @member {Array.} graphs + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.graphs = $util.emptyArray; + + /** + * AttributeProto sparseTensors. + * @member {Array.} sparseTensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensors = $util.emptyArray; + + /** + * AttributeProto typeProtos. + * @member {Array.} typeProtos + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.typeProtos = $util.emptyArray; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @function create + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto=} [properties] Properties to set + * @returns {onnx.AttributeProto} AttributeProto instance + */ + AttributeProto.create = function create(properties) { + return new AttributeProto(properties); + }; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.f != null && Object.hasOwnProperty.call(message, "f")) + writer.uint32(/* id 2, wireType 5 =*/21).float(message.f); + if (message.i != null && Object.hasOwnProperty.call(message, "i")) + writer.uint32(/* id 3, wireType 0 =*/24).int64(message.i); + if (message.s != null && Object.hasOwnProperty.call(message, "s")) + writer.uint32(/* id 4, wireType 2 =*/34).bytes(message.s); + if (message.t != null && Object.hasOwnProperty.call(message, "t")) + $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.g != null && Object.hasOwnProperty.call(message, "g")) + $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/50).fork()).ldelim(); + if (message.floats != null && message.floats.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.floats.length; ++i) + writer.float(message.floats[i]); + writer.ldelim(); + } + if (message.ints != null && message.ints.length) { + writer.uint32(/* id 8, wireType 2 =*/66).fork(); + for (var i = 0; i < message.ints.length; ++i) + writer.int64(message.ints[i]); + writer.ldelim(); + } + if (message.strings != null && message.strings.length) + for (var i = 0; i < message.strings.length; ++i) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.strings[i]); + if (message.tensors != null && message.tensors.length) + for (var i = 0; i < message.tensors.length; ++i) + $root.onnx.TensorProto.encode(message.tensors[i], writer.uint32(/* id 10, wireType 2 =*/82).fork()).ldelim(); + if (message.graphs != null && message.graphs.length) + for (var i = 0; i < message.graphs.length; ++i) + $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 13, wireType 2 =*/106).string(message.docString); + if (message.tp != null && Object.hasOwnProperty.call(message, "tp")) + $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.typeProtos != null && message.typeProtos.length) + for (var i = 0; i < message.typeProtos.length; ++i) + $root.onnx.TypeProto.encode(message.typeProtos[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + writer.uint32(/* id 20, wireType 0 =*/160).int32(message.type); + if (message.refAttrName != null && Object.hasOwnProperty.call(message, "refAttrName")) + writer.uint32(/* id 21, wireType 2 =*/170).string(message.refAttrName); + if (message.sparseTensor != null && Object.hasOwnProperty.call(message, "sparseTensor")) + $root.onnx.SparseTensorProto.encode(message.sparseTensor, writer.uint32(/* id 22, wireType 2 =*/178).fork()).ldelim(); + if (message.sparseTensors != null && message.sparseTensors.length) + for (var i = 0; i < message.sparseTensors.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseTensors[i], 
writer.uint32(/* id 23, wireType 2 =*/186).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.AttributeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 21: { + message.refAttrName = reader.string(); + break; + } + case 13: { + message.docString = reader.string(); + break; + } + case 20: { + message.type = reader.int32(); + break; + } + case 2: { + message.f = reader.float(); + break; + } + case 3: { + message.i = reader.int64(); + break; + } + case 4: { + message.s = reader.bytes(); + break; + } + case 5: { + message.t = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 6: { + message.g = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 22: { + message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); + break; + } + case 14: { + message.tp = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 7: { + if (!(message.floats && message.floats.length)) + message.floats = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floats.push(reader.float()); + } else + message.floats.push(reader.float()); + break; + } + case 8: { + if (!(message.ints && message.ints.length)) + message.ints = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.ints.push(reader.int64()); + } else + message.ints.push(reader.int64()); + break; + } + case 9: { + if (!(message.strings && message.strings.length)) + message.strings = []; + message.strings.push(reader.bytes()); + break; + } + case 10: { + if (!(message.tensors && message.tensors.length)) + message.tensors = []; + message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 11: { + if (!(message.graphs && message.graphs.length)) + message.graphs = []; + message.graphs.push($root.onnx.GraphProto.decode(reader, reader.uint32())); + break; + } + case 23: { + if (!(message.sparseTensors && message.sparseTensors.length)) + message.sparseTensors = []; + message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if 
(!(message.typeProtos && message.typeProtos.length)) + message.typeProtos = []; + message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an AttributeProto message. + * @function verify + * @memberof onnx.AttributeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + AttributeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + if (!$util.isString(message.refAttrName)) + return "refAttrName: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.type != null && message.hasOwnProperty("type")) + switch (message.type) { + default: + return "type: enum value expected"; + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 11: + case 13: + case 6: + case 7: + case 8: + case 9: + case 10: + case 12: + case 14: + break; + } + if (message.f != null && message.hasOwnProperty("f")) + if (typeof message.f !== "number") + return "f: number expected"; + if (message.i != null && message.hasOwnProperty("i")) + if (!$util.isInteger(message.i) && !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high))) + return "i: integer|Long expected"; + if (message.s != null && message.hasOwnProperty("s")) + if (!(message.s && typeof message.s.length === "number" || $util.isString(message.s))) + return "s: buffer expected"; + if (message.t != null && message.hasOwnProperty("t")) { + var error = $root.onnx.TensorProto.verify(message.t); + if (error) + return "t." + error; + } + if (message.g != null && message.hasOwnProperty("g")) { + var error = $root.onnx.GraphProto.verify(message.g); + if (error) + return "g." + error; + } + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor); + if (error) + return "sparseTensor." + error; + } + if (message.tp != null && message.hasOwnProperty("tp")) { + var error = $root.onnx.TypeProto.verify(message.tp); + if (error) + return "tp." 
+ error; + } + if (message.floats != null && message.hasOwnProperty("floats")) { + if (!Array.isArray(message.floats)) + return "floats: array expected"; + for (var i = 0; i < message.floats.length; ++i) + if (typeof message.floats[i] !== "number") + return "floats: number[] expected"; + } + if (message.ints != null && message.hasOwnProperty("ints")) { + if (!Array.isArray(message.ints)) + return "ints: array expected"; + for (var i = 0; i < message.ints.length; ++i) + if (!$util.isInteger(message.ints[i]) && !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high))) + return "ints: integer|Long[] expected"; + } + if (message.strings != null && message.hasOwnProperty("strings")) { + if (!Array.isArray(message.strings)) + return "strings: array expected"; + for (var i = 0; i < message.strings.length; ++i) + if (!(message.strings[i] && typeof message.strings[i].length === "number" || $util.isString(message.strings[i]))) + return "strings: buffer[] expected"; + } + if (message.tensors != null && message.hasOwnProperty("tensors")) { + if (!Array.isArray(message.tensors)) + return "tensors: array expected"; + for (var i = 0; i < message.tensors.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.tensors[i]); + if (error) + return "tensors." + error; + } + } + if (message.graphs != null && message.hasOwnProperty("graphs")) { + if (!Array.isArray(message.graphs)) + return "graphs: array expected"; + for (var i = 0; i < message.graphs.length; ++i) { + var error = $root.onnx.GraphProto.verify(message.graphs[i]); + if (error) + return "graphs." + error; + } + } + if (message.sparseTensors != null && message.hasOwnProperty("sparseTensors")) { + if (!Array.isArray(message.sparseTensors)) + return "sparseTensors: array expected"; + for (var i = 0; i < message.sparseTensors.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); + if (error) + return "sparseTensors." + error; + } + } + if (message.typeProtos != null && message.hasOwnProperty("typeProtos")) { + if (!Array.isArray(message.typeProtos)) + return "typeProtos: array expected"; + for (var i = 0; i < message.typeProtos.length; ++i) { + var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); + if (error) + return "typeProtos." + error; + } + } + return null; + }; + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.AttributeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.AttributeProto} AttributeProto + */ + AttributeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.AttributeProto) + return object; + var message = new $root.onnx.AttributeProto(); + if (object.name != null) + message.name = String(object.name); + if (object.refAttrName != null) + message.refAttrName = String(object.refAttrName); + if (object.docString != null) + message.docString = String(object.docString); + switch (object.type) { + default: + if (typeof object.type === "number") { + message.type = object.type; + break; + } + break; + case "UNDEFINED": + case 0: + message.type = 0; + break; + case "FLOAT": + case 1: + message.type = 1; + break; + case "INT": + case 2: + message.type = 2; + break; + case "STRING": + case 3: + message.type = 3; + break; + case "TENSOR": + case 4: + message.type = 4; + break; + case "GRAPH": + case 5: + message.type = 5; + break; + case "SPARSE_TENSOR": + case 11: + message.type = 11; + break; + case "TYPE_PROTO": + case 13: + message.type = 13; + break; + case "FLOATS": + case 6: + message.type = 6; + break; + case "INTS": + case 7: + message.type = 7; + break; + case "STRINGS": + case 8: + message.type = 8; + break; + case "TENSORS": + case 9: + message.type = 9; + break; + case "GRAPHS": + case 10: + message.type = 10; + break; + case "SPARSE_TENSORS": + case 12: + message.type = 12; + break; + case "TYPE_PROTOS": + case 14: + message.type = 14; + break; + } + if (object.f != null) + message.f = Number(object.f); + if (object.i != null) + if ($util.Long) + (message.i = $util.Long.fromValue(object.i)).unsigned = false; + else if (typeof object.i === "string") + message.i = parseInt(object.i, 10); + else if (typeof object.i === "number") + message.i = object.i; + else if (typeof object.i === "object") + message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber(); + if (object.s != null) + if (typeof object.s === "string") + $util.base64.decode(object.s, message.s = $util.newBuffer($util.base64.length(object.s)), 0); + else if (object.s.length >= 0) + message.s = object.s; + if (object.t != null) { + if (typeof object.t !== "object") + throw TypeError(".onnx.AttributeProto.t: object expected"); + message.t = $root.onnx.TensorProto.fromObject(object.t); + } + if (object.g != null) { + if (typeof object.g !== "object") + throw TypeError(".onnx.AttributeProto.g: object expected"); + message.g = $root.onnx.GraphProto.fromObject(object.g); + } + if (object.sparseTensor != null) { + if (typeof object.sparseTensor !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensor: object expected"); + message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor); + } + if (object.tp != null) { + if (typeof object.tp !== "object") + throw TypeError(".onnx.AttributeProto.tp: object expected"); + message.tp = $root.onnx.TypeProto.fromObject(object.tp); + } + if (object.floats) { + if (!Array.isArray(object.floats)) + throw TypeError(".onnx.AttributeProto.floats: array expected"); + message.floats = []; + for (var i = 0; i < object.floats.length; ++i) + message.floats[i] = Number(object.floats[i]); + } + if (object.ints) { + if (!Array.isArray(object.ints)) + throw TypeError(".onnx.AttributeProto.ints: array expected"); + message.ints = []; + for (var i = 0; i < object.ints.length; ++i) + if ($util.Long) + (message.ints[i] = 
$util.Long.fromValue(object.ints[i])).unsigned = false; + else if (typeof object.ints[i] === "string") + message.ints[i] = parseInt(object.ints[i], 10); + else if (typeof object.ints[i] === "number") + message.ints[i] = object.ints[i]; + else if (typeof object.ints[i] === "object") + message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber(); + } + if (object.strings) { + if (!Array.isArray(object.strings)) + throw TypeError(".onnx.AttributeProto.strings: array expected"); + message.strings = []; + for (var i = 0; i < object.strings.length; ++i) + if (typeof object.strings[i] === "string") + $util.base64.decode(object.strings[i], message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i])), 0); + else if (object.strings[i].length >= 0) + message.strings[i] = object.strings[i]; + } + if (object.tensors) { + if (!Array.isArray(object.tensors)) + throw TypeError(".onnx.AttributeProto.tensors: array expected"); + message.tensors = []; + for (var i = 0; i < object.tensors.length; ++i) { + if (typeof object.tensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.tensors: object expected"); + message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]); + } + } + if (object.graphs) { + if (!Array.isArray(object.graphs)) + throw TypeError(".onnx.AttributeProto.graphs: array expected"); + message.graphs = []; + for (var i = 0; i < object.graphs.length; ++i) { + if (typeof object.graphs[i] !== "object") + throw TypeError(".onnx.AttributeProto.graphs: object expected"); + message.graphs[i] = $root.onnx.GraphProto.fromObject(object.graphs[i]); + } + } + if (object.sparseTensors) { + if (!Array.isArray(object.sparseTensors)) + throw TypeError(".onnx.AttributeProto.sparseTensors: array expected"); + message.sparseTensors = []; + for (var i = 0; i < object.sparseTensors.length; ++i) { + if (typeof object.sparseTensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensors: object expected"); + message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]); + } + } + if (object.typeProtos) { + if (!Array.isArray(object.typeProtos)) + throw TypeError(".onnx.AttributeProto.typeProtos: array expected"); + message.typeProtos = []; + for (var i = 0; i < object.typeProtos.length; ++i) { + if (typeof object.typeProtos[i] !== "object") + throw TypeError(".onnx.AttributeProto.typeProtos: object expected"); + message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]); + } + } + return message; + }; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.AttributeProto + * @static + * @param {onnx.AttributeProto} message AttributeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + AttributeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.floats = []; + object.ints = []; + object.strings = []; + object.tensors = []; + object.graphs = []; + object.typeProtos = []; + object.sparseTensors = []; + } + if (options.defaults) { + object.name = ""; + object.f = 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.i = options.longs === String ? 
"0" : 0; + if (options.bytes === String) + object.s = ""; + else { + object.s = []; + if (options.bytes !== Array) + object.s = $util.newBuffer(object.s); + } + object.t = null; + object.g = null; + object.docString = ""; + object.tp = null; + object.type = options.enums === String ? "UNDEFINED" : 0; + object.refAttrName = ""; + object.sparseTensor = null; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.f != null && message.hasOwnProperty("f")) + object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; + if (message.i != null && message.hasOwnProperty("i")) + if (typeof message.i === "number") + object.i = options.longs === String ? String(message.i) : message.i; + else + object.i = options.longs === String ? $util.Long.prototype.toString.call(message.i) : options.longs === Number ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() : message.i; + if (message.s != null && message.hasOwnProperty("s")) + object.s = options.bytes === String ? $util.base64.encode(message.s, 0, message.s.length) : options.bytes === Array ? Array.prototype.slice.call(message.s) : message.s; + if (message.t != null && message.hasOwnProperty("t")) + object.t = $root.onnx.TensorProto.toObject(message.t, options); + if (message.g != null && message.hasOwnProperty("g")) + object.g = $root.onnx.GraphProto.toObject(message.g, options); + if (message.floats && message.floats.length) { + object.floats = []; + for (var j = 0; j < message.floats.length; ++j) + object.floats[j] = options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; + } + if (message.ints && message.ints.length) { + object.ints = []; + for (var j = 0; j < message.ints.length; ++j) + if (typeof message.ints[j] === "number") + object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; + else + object.ints[j] = options.longs === String ? $util.Long.prototype.toString.call(message.ints[j]) : options.longs === Number ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() : message.ints[j]; + } + if (message.strings && message.strings.length) { + object.strings = []; + for (var j = 0; j < message.strings.length; ++j) + object.strings[j] = options.bytes === String ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.strings[j]) : message.strings[j]; + } + if (message.tensors && message.tensors.length) { + object.tensors = []; + for (var j = 0; j < message.tensors.length; ++j) + object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); + } + if (message.graphs && message.graphs.length) { + object.graphs = []; + for (var j = 0; j < message.graphs.length; ++j) + object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.tp != null && message.hasOwnProperty("tp")) + object.tp = $root.onnx.TypeProto.toObject(message.tp, options); + if (message.typeProtos && message.typeProtos.length) { + object.typeProtos = []; + for (var j = 0; j < message.typeProtos.length; ++j) + object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); + } + if (message.type != null && message.hasOwnProperty("type")) + object.type = options.enums === String ? 
$root.onnx.AttributeProto.AttributeType[message.type] === undefined ? message.type : $root.onnx.AttributeProto.AttributeType[message.type] : message.type; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + object.refAttrName = message.refAttrName; + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) + object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); + if (message.sparseTensors && message.sparseTensors.length) { + object.sparseTensors = []; + for (var j = 0; j < message.sparseTensors.length; ++j) + object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); + } + return object; + }; + + /** + * Converts this AttributeProto to JSON. + * @function toJSON + * @memberof onnx.AttributeProto + * @instance + * @returns {Object.} JSON object + */ + AttributeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for AttributeProto + * @function getTypeUrl + * @memberof onnx.AttributeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.AttributeProto"; + }; + + /** + * AttributeType enum. + * @name onnx.AttributeProto.AttributeType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} INT=2 INT value + * @property {number} STRING=3 STRING value + * @property {number} TENSOR=4 TENSOR value + * @property {number} GRAPH=5 GRAPH value + * @property {number} SPARSE_TENSOR=11 SPARSE_TENSOR value + * @property {number} TYPE_PROTO=13 TYPE_PROTO value + * @property {number} FLOATS=6 FLOATS value + * @property {number} INTS=7 INTS value + * @property {number} STRINGS=8 STRINGS value + * @property {number} TENSORS=9 TENSORS value + * @property {number} GRAPHS=10 GRAPHS value + * @property {number} SPARSE_TENSORS=12 SPARSE_TENSORS value + * @property {number} TYPE_PROTOS=14 TYPE_PROTOS value + */ + AttributeProto.AttributeType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "INT"] = 2; + values[valuesById[3] = "STRING"] = 3; + values[valuesById[4] = "TENSOR"] = 4; + values[valuesById[5] = "GRAPH"] = 5; + values[valuesById[11] = "SPARSE_TENSOR"] = 11; + values[valuesById[13] = "TYPE_PROTO"] = 13; + values[valuesById[6] = "FLOATS"] = 6; + values[valuesById[7] = "INTS"] = 7; + values[valuesById[8] = "STRINGS"] = 8; + values[valuesById[9] = "TENSORS"] = 9; + values[valuesById[10] = "GRAPHS"] = 10; + values[valuesById[12] = "SPARSE_TENSORS"] = 12; + values[valuesById[14] = "TYPE_PROTOS"] = 14; + return values; + })(); + + return AttributeProto; + })(); + + onnx.ValueInfoProto = (function() { + + /** + * Properties of a ValueInfoProto. + * @memberof onnx + * @interface IValueInfoProto + * @property {string|null} [name] ValueInfoProto name + * @property {onnx.ITypeProto|null} [type] ValueInfoProto type + * @property {string|null} [docString] ValueInfoProto docString + */ + + /** + * Constructs a new ValueInfoProto. + * @memberof onnx + * @classdesc Represents a ValueInfoProto. 
+ * @implements IValueInfoProto + * @constructor + * @param {onnx.IValueInfoProto=} [properties] Properties to set + */ + function ValueInfoProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ValueInfoProto name. + * @member {string} name + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.name = ""; + + /** + * ValueInfoProto type. + * @member {onnx.ITypeProto|null|undefined} type + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.type = null; + + /** + * ValueInfoProto docString. + * @member {string} docString + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.docString = ""; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @function create + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto=} [properties] Properties to set + * @returns {onnx.ValueInfoProto} ValueInfoProto instance + */ + ValueInfoProto.create = function create(properties) { + return new ValueInfoProto(properties); + }; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encode + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.docString); + return writer; + }; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.ValueInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 2: { + message.type = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 3: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ValueInfoProto message. + * @function verify + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ValueInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.type != null && message.hasOwnProperty("type")) { + var error = $root.onnx.TypeProto.verify(message.type); + if (error) + return "type." + error; + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ValueInfoProto} ValueInfoProto + */ + ValueInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ValueInfoProto) + return object; + var message = new $root.onnx.ValueInfoProto(); + if (object.name != null) + message.name = String(object.name); + if (object.type != null) { + if (typeof object.type !== "object") + throw TypeError(".onnx.ValueInfoProto.type: object expected"); + message.type = $root.onnx.TypeProto.fromObject(object.type); + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.ValueInfoProto} message ValueInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ValueInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.name = ""; + object.type = null; + object.docString = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.type != null && message.hasOwnProperty("type")) + object.type = $root.onnx.TypeProto.toObject(message.type, options); + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + return object; + }; + + /** + * Converts this ValueInfoProto to JSON. + * @function toJSON + * @memberof onnx.ValueInfoProto + * @instance + * @returns {Object.} JSON object + */ + ValueInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ValueInfoProto + * @function getTypeUrl + * @memberof onnx.ValueInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ValueInfoProto"; + }; + + return ValueInfoProto; + })(); + + onnx.NodeProto = (function() { + + /** + * Properties of a NodeProto. + * @memberof onnx + * @interface INodeProto + * @property {Array.|null} [input] NodeProto input + * @property {Array.|null} [output] NodeProto output + * @property {string|null} [name] NodeProto name + * @property {string|null} [opType] NodeProto opType + * @property {string|null} [domain] NodeProto domain + * @property {Array.|null} [attribute] NodeProto attribute + * @property {string|null} [docString] NodeProto docString + */ + + /** + * Constructs a new NodeProto. + * @memberof onnx + * @classdesc Represents a NodeProto. + * @implements INodeProto + * @constructor + * @param {onnx.INodeProto=} [properties] Properties to set + */ + function NodeProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * NodeProto input. + * @member {Array.} input + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.input = $util.emptyArray; + + /** + * NodeProto output. + * @member {Array.} output + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.output = $util.emptyArray; + + /** + * NodeProto name. + * @member {string} name + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.name = ""; + + /** + * NodeProto opType. + * @member {string} opType + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.opType = ""; + + /** + * NodeProto domain. + * @member {string} domain + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.domain = ""; + + /** + * NodeProto attribute. + * @member {Array.} attribute + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.attribute = $util.emptyArray; + + /** + * NodeProto docString. 
+ * @member {string} docString + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.docString = ""; + + /** + * Creates a new NodeProto instance using the specified properties. + * @function create + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto=} [properties] Properties to set + * @returns {onnx.NodeProto} NodeProto instance + */ + NodeProto.create = function create(properties) { + return new NodeProto(properties); + }; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encode + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.output[i]); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.name); + if (message.opType != null && Object.hasOwnProperty.call(message, "opType")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.opType); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + $root.onnx.AttributeProto.encode(message.attribute[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 7, wireType 2 =*/58).string(message.domain); + return writer; + }; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.NodeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 2: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 3: { + message.name = reader.string(); + break; + } + case 4: { + message.opType = reader.string(); + break; + } + case 7: { + message.domain = reader.string(); + break; + } + case 5: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a NodeProto message. + * @function verify + * @memberof onnx.NodeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + NodeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.opType != null && message.hasOwnProperty("opType")) + if (!$util.isString(message.opType)) + return "opType: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attribute[i]); + if (error) + return "attribute." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.NodeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.NodeProto} NodeProto + */ + NodeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.NodeProto) + return object; + var message = new $root.onnx.NodeProto(); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.NodeProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.NodeProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.name != null) + message.name = String(object.name); + if (object.opType != null) + message.opType = String(object.opType); + if (object.domain != null) + message.domain = String(object.domain); + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.NodeProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) { + if (typeof object.attribute[i] !== "object") + throw TypeError(".onnx.NodeProto.attribute: object expected"); + message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.NodeProto + * @static + * @param {onnx.NodeProto} message NodeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + NodeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + } + if (options.defaults) { + object.name = ""; + object.opType = ""; + object.docString = ""; + object.domain = ""; + } + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.opType != null && message.hasOwnProperty("opType")) + object.opType = message.opType; + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + return object; + }; + + /** + * Converts this NodeProto to JSON. 
+ * @function toJSON + * @memberof onnx.NodeProto + * @instance + * @returns {Object.} JSON object + */ + NodeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for NodeProto + * @function getTypeUrl + * @memberof onnx.NodeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.NodeProto"; + }; + + return NodeProto; + })(); + + onnx.TrainingInfoProto = (function() { + + /** + * Properties of a TrainingInfoProto. + * @memberof onnx + * @interface ITrainingInfoProto + * @property {onnx.IGraphProto|null} [initialization] TrainingInfoProto initialization + * @property {onnx.IGraphProto|null} [algorithm] TrainingInfoProto algorithm + * @property {Array.|null} [initializationBinding] TrainingInfoProto initializationBinding + * @property {Array.|null} [updateBinding] TrainingInfoProto updateBinding + */ + + /** + * Constructs a new TrainingInfoProto. + * @memberof onnx + * @classdesc Represents a TrainingInfoProto. + * @implements ITrainingInfoProto + * @constructor + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + */ + function TrainingInfoProto(properties) { + this.initializationBinding = []; + this.updateBinding = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TrainingInfoProto initialization. + * @member {onnx.IGraphProto|null|undefined} initialization + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initialization = null; + + /** + * TrainingInfoProto algorithm. + * @member {onnx.IGraphProto|null|undefined} algorithm + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.algorithm = null; + + /** + * TrainingInfoProto initializationBinding. + * @member {Array.} initializationBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initializationBinding = $util.emptyArray; + + /** + * TrainingInfoProto updateBinding. + * @member {Array.} updateBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.updateBinding = $util.emptyArray; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @function create + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + * @returns {onnx.TrainingInfoProto} TrainingInfoProto instance + */ + TrainingInfoProto.create = function create(properties) { + return new TrainingInfoProto(properties); + }; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.initialization != null && Object.hasOwnProperty.call(message, "initialization")) + $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.algorithm != null && Object.hasOwnProperty.call(message, "algorithm")) + $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.initializationBinding != null && message.initializationBinding.length) + for (var i = 0; i < message.initializationBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.initializationBinding[i], writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.updateBinding != null && message.updateBinding.length) + for (var i = 0; i < message.updateBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.updateBinding[i], writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TrainingInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.initializationBinding && message.initializationBinding.length)) + message.initializationBinding = []; + message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 4: { + if (!(message.updateBinding && message.updateBinding.length)) + message.updateBinding = []; + message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TrainingInfoProto message. + * @function verify + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TrainingInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.initialization != null && message.hasOwnProperty("initialization")) { + var error = $root.onnx.GraphProto.verify(message.initialization); + if (error) + return "initialization." + error; + } + if (message.algorithm != null && message.hasOwnProperty("algorithm")) { + var error = $root.onnx.GraphProto.verify(message.algorithm); + if (error) + return "algorithm." + error; + } + if (message.initializationBinding != null && message.hasOwnProperty("initializationBinding")) { + if (!Array.isArray(message.initializationBinding)) + return "initializationBinding: array expected"; + for (var i = 0; i < message.initializationBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); + if (error) + return "initializationBinding." + error; + } + } + if (message.updateBinding != null && message.hasOwnProperty("updateBinding")) { + if (!Array.isArray(message.updateBinding)) + return "updateBinding: array expected"; + for (var i = 0; i < message.updateBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); + if (error) + return "updateBinding." + error; + } + } + return null; + }; + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + */ + TrainingInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TrainingInfoProto) + return object; + var message = new $root.onnx.TrainingInfoProto(); + if (object.initialization != null) { + if (typeof object.initialization !== "object") + throw TypeError(".onnx.TrainingInfoProto.initialization: object expected"); + message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); + } + if (object.algorithm != null) { + if (typeof object.algorithm !== "object") + throw TypeError(".onnx.TrainingInfoProto.algorithm: object expected"); + message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); + } + if (object.initializationBinding) { + if (!Array.isArray(object.initializationBinding)) + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: array expected"); + message.initializationBinding = []; + for (var i = 0; i < object.initializationBinding.length; ++i) { + if (typeof object.initializationBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: object expected"); + message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.initializationBinding[i]); + } + } + if (object.updateBinding) { + if (!Array.isArray(object.updateBinding)) + throw TypeError(".onnx.TrainingInfoProto.updateBinding: array expected"); + message.updateBinding = []; + for (var i = 0; i < object.updateBinding.length; ++i) { + if (typeof object.updateBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.updateBinding: object expected"); + message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.TrainingInfoProto} message TrainingInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TrainingInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.initializationBinding = []; + object.updateBinding = []; + } + if (options.defaults) { + object.initialization = null; + object.algorithm = null; + } + if (message.initialization != null && message.hasOwnProperty("initialization")) + object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); + if (message.algorithm != null && message.hasOwnProperty("algorithm")) + object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); + if (message.initializationBinding && message.initializationBinding.length) { + object.initializationBinding = []; + for (var j = 0; j < message.initializationBinding.length; ++j) + object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.initializationBinding[j], options); + } + if (message.updateBinding && message.updateBinding.length) { + object.updateBinding = []; + for (var j = 0; j < message.updateBinding.length; ++j) + object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); + } + return object; + }; + + /** + * Converts this TrainingInfoProto to JSON. 
+ * @function toJSON + * @memberof onnx.TrainingInfoProto + * @instance + * @returns {Object.} JSON object + */ + TrainingInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TrainingInfoProto + * @function getTypeUrl + * @memberof onnx.TrainingInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TrainingInfoProto"; + }; + + return TrainingInfoProto; + })(); + + onnx.ModelProto = (function() { + + /** + * Properties of a ModelProto. + * @memberof onnx + * @interface IModelProto + * @property {number|Long|null} [irVersion] ModelProto irVersion + * @property {Array.|null} [opsetImport] ModelProto opsetImport + * @property {string|null} [producerName] ModelProto producerName + * @property {string|null} [producerVersion] ModelProto producerVersion + * @property {string|null} [domain] ModelProto domain + * @property {number|Long|null} [modelVersion] ModelProto modelVersion + * @property {string|null} [docString] ModelProto docString + * @property {onnx.IGraphProto|null} [graph] ModelProto graph + * @property {Array.|null} [metadataProps] ModelProto metadataProps + * @property {Array.|null} [trainingInfo] ModelProto trainingInfo + * @property {Array.|null} [functions] ModelProto functions + */ + + /** + * Constructs a new ModelProto. + * @memberof onnx + * @classdesc Represents a ModelProto. + * @implements IModelProto + * @constructor + * @param {onnx.IModelProto=} [properties] Properties to set + */ + function ModelProto(properties) { + this.opsetImport = []; + this.metadataProps = []; + this.trainingInfo = []; + this.functions = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ModelProto irVersion. + * @member {number|Long} irVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.opsetImport = $util.emptyArray; + + /** + * ModelProto producerName. + * @member {string} producerName + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerName = ""; + + /** + * ModelProto producerVersion. + * @member {string} producerVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerVersion = ""; + + /** + * ModelProto domain. + * @member {string} domain + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.domain = ""; + + /** + * ModelProto modelVersion. + * @member {number|Long} modelVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto docString. + * @member {string} docString + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.docString = ""; + + /** + * ModelProto graph. 
+ * @member {onnx.IGraphProto|null|undefined} graph + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.graph = null; + + /** + * ModelProto metadataProps. + * @member {Array.} metadataProps + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.metadataProps = $util.emptyArray; + + /** + * ModelProto trainingInfo. + * @member {Array.} trainingInfo + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.trainingInfo = $util.emptyArray; + + /** + * ModelProto functions. + * @member {Array.} functions + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.functions = $util.emptyArray; + + /** + * Creates a new ModelProto instance using the specified properties. + * @function create + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto=} [properties] Properties to set + * @returns {onnx.ModelProto} ModelProto instance + */ + ModelProto.create = function create(properties) { + return new ModelProto(properties); + }; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encode + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.irVersion != null && Object.hasOwnProperty.call(message, "irVersion")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.irVersion); + if (message.producerName != null && Object.hasOwnProperty.call(message, "producerName")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.producerName); + if (message.producerVersion != null && Object.hasOwnProperty.call(message, "producerVersion")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.producerVersion); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.domain); + if (message.modelVersion != null && Object.hasOwnProperty.call(message, "modelVersion")) + writer.uint32(/* id 5, wireType 0 =*/40).int64(message.modelVersion); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.graph != null && Object.hasOwnProperty.call(message, "graph")) + $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.metadataProps != null && message.metadataProps.length) + for (var i = 0; i < message.metadataProps.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.metadataProps[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.trainingInfo != null && message.trainingInfo.length) + for (var i = 0; i < message.trainingInfo.length; ++i) + $root.onnx.TrainingInfoProto.encode(message.trainingInfo[i], writer.uint32(/* id 20, wireType 2 =*/162).fork()).ldelim(); + if (message.functions != null && message.functions.length) + for (var i = 0; i < message.functions.length; ++i) + $root.onnx.FunctionProto.encode(message.functions[i], writer.uint32(/* 
id 25, wireType 2 =*/202).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ModelProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.irVersion = reader.int64(); + break; + } + case 8: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.producerName = reader.string(); + break; + } + case 3: { + message.producerVersion = reader.string(); + break; + } + case 4: { + message.domain = reader.string(); + break; + } + case 5: { + message.modelVersion = reader.int64(); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + case 7: { + message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 14: { + if (!(message.metadataProps && message.metadataProps.length)) + message.metadataProps = []; + message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 20: { + if (!(message.trainingInfo && message.trainingInfo.length)) + message.trainingInfo = []; + message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); + break; + } + case 25: { + if (!(message.functions && message.functions.length)) + message.functions = []; + message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ModelProto message. 
+ * @function verify + * @memberof onnx.ModelProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ModelProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (!$util.isInteger(message.irVersion) && !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high))) + return "irVersion: integer|Long expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.producerName != null && message.hasOwnProperty("producerName")) + if (!$util.isString(message.producerName)) + return "producerName: string expected"; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + if (!$util.isString(message.producerVersion)) + return "producerVersion: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (!$util.isInteger(message.modelVersion) && !(message.modelVersion && $util.isInteger(message.modelVersion.low) && $util.isInteger(message.modelVersion.high))) + return "modelVersion: integer|Long expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.graph != null && message.hasOwnProperty("graph")) { + var error = $root.onnx.GraphProto.verify(message.graph); + if (error) + return "graph." + error; + } + if (message.metadataProps != null && message.hasOwnProperty("metadataProps")) { + if (!Array.isArray(message.metadataProps)) + return "metadataProps: array expected"; + for (var i = 0; i < message.metadataProps.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); + if (error) + return "metadataProps." + error; + } + } + if (message.trainingInfo != null && message.hasOwnProperty("trainingInfo")) { + if (!Array.isArray(message.trainingInfo)) + return "trainingInfo: array expected"; + for (var i = 0; i < message.trainingInfo.length; ++i) { + var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); + if (error) + return "trainingInfo." + error; + } + } + if (message.functions != null && message.hasOwnProperty("functions")) { + if (!Array.isArray(message.functions)) + return "functions: array expected"; + for (var i = 0; i < message.functions.length; ++i) { + var error = $root.onnx.FunctionProto.verify(message.functions[i]); + if (error) + return "functions." + error; + } + } + return null; + }; + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.ModelProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ModelProto} ModelProto + */ + ModelProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ModelProto) + return object; + var message = new $root.onnx.ModelProto(); + if (object.irVersion != null) + if ($util.Long) + (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; + else if (typeof object.irVersion === "string") + message.irVersion = parseInt(object.irVersion, 10); + else if (typeof object.irVersion === "number") + message.irVersion = object.irVersion; + else if (typeof object.irVersion === "object") + message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.ModelProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.ModelProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.producerName != null) + message.producerName = String(object.producerName); + if (object.producerVersion != null) + message.producerVersion = String(object.producerVersion); + if (object.domain != null) + message.domain = String(object.domain); + if (object.modelVersion != null) + if ($util.Long) + (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; + else if (typeof object.modelVersion === "string") + message.modelVersion = parseInt(object.modelVersion, 10); + else if (typeof object.modelVersion === "number") + message.modelVersion = object.modelVersion; + else if (typeof object.modelVersion === "object") + message.modelVersion = new $util.LongBits(object.modelVersion.low >>> 0, object.modelVersion.high >>> 0).toNumber(); + if (object.docString != null) + message.docString = String(object.docString); + if (object.graph != null) { + if (typeof object.graph !== "object") + throw TypeError(".onnx.ModelProto.graph: object expected"); + message.graph = $root.onnx.GraphProto.fromObject(object.graph); + } + if (object.metadataProps) { + if (!Array.isArray(object.metadataProps)) + throw TypeError(".onnx.ModelProto.metadataProps: array expected"); + message.metadataProps = []; + for (var i = 0; i < object.metadataProps.length; ++i) { + if (typeof object.metadataProps[i] !== "object") + throw TypeError(".onnx.ModelProto.metadataProps: object expected"); + message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); + } + } + if (object.trainingInfo) { + if (!Array.isArray(object.trainingInfo)) + throw TypeError(".onnx.ModelProto.trainingInfo: array expected"); + message.trainingInfo = []; + for (var i = 0; i < object.trainingInfo.length; ++i) { + if (typeof object.trainingInfo[i] !== "object") + throw TypeError(".onnx.ModelProto.trainingInfo: object expected"); + message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); + } + } + if (object.functions) { + if (!Array.isArray(object.functions)) + throw TypeError(".onnx.ModelProto.functions: array expected"); + message.functions = []; + for (var i = 0; i < object.functions.length; ++i) { + if (typeof object.functions[i] !== "object") + throw TypeError(".onnx.ModelProto.functions: object expected"); + message.functions[i] 
= $root.onnx.FunctionProto.fromObject(object.functions[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ModelProto + * @static + * @param {onnx.ModelProto} message ModelProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ModelProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.opsetImport = []; + object.metadataProps = []; + object.trainingInfo = []; + object.functions = []; + } + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.irVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.irVersion = options.longs === String ? "0" : 0; + object.producerName = ""; + object.producerVersion = ""; + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.modelVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.modelVersion = options.longs === String ? "0" : 0; + object.docString = ""; + object.graph = null; + } + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (typeof message.irVersion === "number") + object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; + else + object.irVersion = options.longs === String ? $util.Long.prototype.toString.call(message.irVersion) : options.longs === Number ? new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() : message.irVersion; + if (message.producerName != null && message.hasOwnProperty("producerName")) + object.producerName = message.producerName; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + object.producerVersion = message.producerVersion; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (typeof message.modelVersion === "number") + object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; + else + object.modelVersion = options.longs === String ? $util.Long.prototype.toString.call(message.modelVersion) : options.longs === Number ? 
new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() : message.modelVersion; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.graph != null && message.hasOwnProperty("graph")) + object.graph = $root.onnx.GraphProto.toObject(message.graph, options); + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.metadataProps && message.metadataProps.length) { + object.metadataProps = []; + for (var j = 0; j < message.metadataProps.length; ++j) + object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); + } + if (message.trainingInfo && message.trainingInfo.length) { + object.trainingInfo = []; + for (var j = 0; j < message.trainingInfo.length; ++j) + object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); + } + if (message.functions && message.functions.length) { + object.functions = []; + for (var j = 0; j < message.functions.length; ++j) + object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + } + return object; + }; + + /** + * Converts this ModelProto to JSON. + * @function toJSON + * @memberof onnx.ModelProto + * @instance + * @returns {Object.} JSON object + */ + ModelProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ModelProto + * @function getTypeUrl + * @memberof onnx.ModelProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ModelProto"; + }; + + return ModelProto; + })(); + + onnx.StringStringEntryProto = (function() { + + /** + * Properties of a StringStringEntryProto. + * @memberof onnx + * @interface IStringStringEntryProto + * @property {string|null} [key] StringStringEntryProto key + * @property {string|null} [value] StringStringEntryProto value + */ + + /** + * Constructs a new StringStringEntryProto. + * @memberof onnx + * @classdesc Represents a StringStringEntryProto. + * @implements IStringStringEntryProto + * @constructor + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + */ + function StringStringEntryProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * StringStringEntryProto key. + * @member {string} key + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.key = ""; + + /** + * StringStringEntryProto value. + * @member {string} value + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.value = ""; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. 
+ * @function create + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + * @returns {onnx.StringStringEntryProto} StringStringEntryProto instance + */ + StringStringEntryProto.create = function create(properties) { + return new StringStringEntryProto(properties); + }; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encode + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.key != null && Object.hasOwnProperty.call(message, "key")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.key); + if (message.value != null && Object.hasOwnProperty.call(message, "value")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.value); + return writer; + }; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.StringStringEntryProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.key = reader.string(); + break; + } + case 2: { + message.value = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a StringStringEntryProto message. + * @function verify + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + StringStringEntryProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.key != null && message.hasOwnProperty("key")) + if (!$util.isString(message.key)) + return "key: string expected"; + if (message.value != null && message.hasOwnProperty("value")) + if (!$util.isString(message.value)) + return "value: string expected"; + return null; + }; + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + */ + StringStringEntryProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.StringStringEntryProto) + return object; + var message = new $root.onnx.StringStringEntryProto(); + if (object.key != null) + message.key = String(object.key); + if (object.value != null) + message.value = String(object.value); + return message; + }; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.StringStringEntryProto} message StringStringEntryProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + StringStringEntryProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.key = ""; + object.value = ""; + } + if (message.key != null && message.hasOwnProperty("key")) + object.key = message.key; + if (message.value != null && message.hasOwnProperty("value")) + object.value = message.value; + return object; + }; + + /** + * Converts this StringStringEntryProto to JSON. 
+ * @function toJSON + * @memberof onnx.StringStringEntryProto + * @instance + * @returns {Object.} JSON object + */ + StringStringEntryProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for StringStringEntryProto + * @function getTypeUrl + * @memberof onnx.StringStringEntryProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.StringStringEntryProto"; + }; + + return StringStringEntryProto; + })(); + + onnx.TensorAnnotation = (function() { + + /** + * Properties of a TensorAnnotation. + * @memberof onnx + * @interface ITensorAnnotation + * @property {string|null} [tensorName] TensorAnnotation tensorName + * @property {Array.|null} [quantParameterTensorNames] TensorAnnotation quantParameterTensorNames + */ + + /** + * Constructs a new TensorAnnotation. + * @memberof onnx + * @classdesc Represents a TensorAnnotation. + * @implements ITensorAnnotation + * @constructor + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + */ + function TensorAnnotation(properties) { + this.quantParameterTensorNames = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorAnnotation tensorName. + * @member {string} tensorName + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.tensorName = ""; + + /** + * TensorAnnotation quantParameterTensorNames. + * @member {Array.} quantParameterTensorNames + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.quantParameterTensorNames = $util.emptyArray; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @function create + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + * @returns {onnx.TensorAnnotation} TensorAnnotation instance + */ + TensorAnnotation.create = function create(properties) { + return new TensorAnnotation(properties); + }; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encode + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorName != null && Object.hasOwnProperty.call(message, "tensorName")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.tensorName); + if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.quantParameterTensorNames[i], writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorAnnotation(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorName = reader.string(); + break; + } + case 2: { + if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) + message.quantParameterTensorNames = []; + message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorAnnotation message. + * @function verify + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorAnnotation.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + if (!$util.isString(message.tensorName)) + return "tensorName: string expected"; + if (message.quantParameterTensorNames != null && message.hasOwnProperty("quantParameterTensorNames")) { + if (!Array.isArray(message.quantParameterTensorNames)) + return "quantParameterTensorNames: array expected"; + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); + if (error) + return "quantParameterTensorNames." + error; + } + } + return null; + }; + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorAnnotation} TensorAnnotation + */ + TensorAnnotation.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorAnnotation) + return object; + var message = new $root.onnx.TensorAnnotation(); + if (object.tensorName != null) + message.tensorName = String(object.tensorName); + if (object.quantParameterTensorNames) { + if (!Array.isArray(object.quantParameterTensorNames)) + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: array expected"); + message.quantParameterTensorNames = []; + for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { + if (typeof object.quantParameterTensorNames[i] !== "object") + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: object expected"); + message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject(object.quantParameterTensorNames[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.TensorAnnotation} message TensorAnnotation + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorAnnotation.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.quantParameterTensorNames = []; + if (options.defaults) + object.tensorName = ""; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + object.tensorName = message.tensorName; + if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { + object.quantParameterTensorNames = []; + for (var j = 0; j < message.quantParameterTensorNames.length; ++j) + object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject(message.quantParameterTensorNames[j], options); + } + return object; + }; + + /** + * Converts this TensorAnnotation to JSON. + * @function toJSON + * @memberof onnx.TensorAnnotation + * @instance + * @returns {Object.} JSON object + */ + TensorAnnotation.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorAnnotation + * @function getTypeUrl + * @memberof onnx.TensorAnnotation + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorAnnotation"; + }; + + return TensorAnnotation; + })(); + + onnx.GraphProto = (function() { + + /** + * Properties of a GraphProto. 
+ * @memberof onnx + * @interface IGraphProto + * @property {Array.|null} [node] GraphProto node + * @property {string|null} [name] GraphProto name + * @property {Array.|null} [initializer] GraphProto initializer + * @property {Array.|null} [sparseInitializer] GraphProto sparseInitializer + * @property {string|null} [docString] GraphProto docString + * @property {Array.|null} [input] GraphProto input + * @property {Array.|null} [output] GraphProto output + * @property {Array.|null} [valueInfo] GraphProto valueInfo + * @property {Array.|null} [quantizationAnnotation] GraphProto quantizationAnnotation + */ + + /** + * Constructs a new GraphProto. + * @memberof onnx + * @classdesc Represents a GraphProto. + * @implements IGraphProto + * @constructor + * @param {onnx.IGraphProto=} [properties] Properties to set + */ + function GraphProto(properties) { + this.node = []; + this.initializer = []; + this.sparseInitializer = []; + this.input = []; + this.output = []; + this.valueInfo = []; + this.quantizationAnnotation = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * GraphProto node. + * @member {Array.} node + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.node = $util.emptyArray; + + /** + * GraphProto name. + * @member {string} name + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.name = ""; + + /** + * GraphProto initializer. + * @member {Array.} initializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.initializer = $util.emptyArray; + + /** + * GraphProto sparseInitializer. + * @member {Array.} sparseInitializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.sparseInitializer = $util.emptyArray; + + /** + * GraphProto docString. + * @member {string} docString + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.docString = ""; + + /** + * GraphProto input. + * @member {Array.} input + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.input = $util.emptyArray; + + /** + * GraphProto output. + * @member {Array.} output + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.output = $util.emptyArray; + + /** + * GraphProto valueInfo. + * @member {Array.} valueInfo + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.valueInfo = $util.emptyArray; + + /** + * GraphProto quantizationAnnotation. + * @member {Array.} quantizationAnnotation + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.quantizationAnnotation = $util.emptyArray; + + /** + * Creates a new GraphProto instance using the specified properties. + * @function create + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto=} [properties] Properties to set + * @returns {onnx.GraphProto} GraphProto instance + */ + GraphProto.create = function create(properties) { + return new GraphProto(properties); + }; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.name); + if (message.initializer != null && message.initializer.length) + for (var i = 0; i < message.initializer.length; ++i) + $root.onnx.TensorProto.encode(message.initializer[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.docString); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + $root.onnx.ValueInfoProto.encode(message.input[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + $root.onnx.ValueInfoProto.encode(message.output[i], writer.uint32(/* id 12, wireType 2 =*/98).fork()).ldelim(); + if (message.valueInfo != null && message.valueInfo.length) + for (var i = 0; i < message.valueInfo.length; ++i) + $root.onnx.ValueInfoProto.encode(message.valueInfo[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) + for (var i = 0; i < message.quantizationAnnotation.length; ++i) + $root.onnx.TensorAnnotation.encode(message.quantizationAnnotation[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.sparseInitializer != null && message.sparseInitializer.length) + for (var i = 0; i < message.sparseInitializer.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseInitializer[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.GraphProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.name = reader.string(); + break; + } + case 5: { + if (!(message.initializer && message.initializer.length)) + message.initializer = []; + message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.sparseInitializer && message.sparseInitializer.length)) + message.sparseInitializer = []; + message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.docString = reader.string(); + break; + } + case 11: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 12: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 13: { + if (!(message.valueInfo && message.valueInfo.length)) + message.valueInfo = []; + message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 14: { + if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) + message.quantizationAnnotation = []; + message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a GraphProto message. + * @function verify + * @memberof onnx.GraphProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + GraphProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.initializer != null && message.hasOwnProperty("initializer")) { + if (!Array.isArray(message.initializer)) + return "initializer: array expected"; + for (var i = 0; i < message.initializer.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.initializer[i]); + if (error) + return "initializer." 
+ error; + } + } + if (message.sparseInitializer != null && message.hasOwnProperty("sparseInitializer")) { + if (!Array.isArray(message.sparseInitializer)) + return "sparseInitializer: array expected"; + for (var i = 0; i < message.sparseInitializer.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); + if (error) + return "sparseInitializer." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.input[i]); + if (error) + return "input." + error; + } + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.output[i]); + if (error) + return "output." + error; + } + } + if (message.valueInfo != null && message.hasOwnProperty("valueInfo")) { + if (!Array.isArray(message.valueInfo)) + return "valueInfo: array expected"; + for (var i = 0; i < message.valueInfo.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); + if (error) + return "valueInfo." + error; + } + } + if (message.quantizationAnnotation != null && message.hasOwnProperty("quantizationAnnotation")) { + if (!Array.isArray(message.quantizationAnnotation)) + return "quantizationAnnotation: array expected"; + for (var i = 0; i < message.quantizationAnnotation.length; ++i) { + var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); + if (error) + return "quantizationAnnotation." + error; + } + } + return null; + }; + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.GraphProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.GraphProto} GraphProto + */ + GraphProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.GraphProto) + return object; + var message = new $root.onnx.GraphProto(); + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.GraphProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.GraphProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.name != null) + message.name = String(object.name); + if (object.initializer) { + if (!Array.isArray(object.initializer)) + throw TypeError(".onnx.GraphProto.initializer: array expected"); + message.initializer = []; + for (var i = 0; i < object.initializer.length; ++i) { + if (typeof object.initializer[i] !== "object") + throw TypeError(".onnx.GraphProto.initializer: object expected"); + message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); + } + } + if (object.sparseInitializer) { + if (!Array.isArray(object.sparseInitializer)) + throw TypeError(".onnx.GraphProto.sparseInitializer: array expected"); + message.sparseInitializer = []; + for (var i = 0; i < object.sparseInitializer.length; ++i) { + if (typeof object.sparseInitializer[i] !== "object") + throw TypeError(".onnx.GraphProto.sparseInitializer: object expected"); + message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.GraphProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) { + if (typeof object.input[i] !== "object") + throw TypeError(".onnx.GraphProto.input: object expected"); + message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); + } + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.GraphProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) { + if (typeof object.output[i] !== "object") + throw TypeError(".onnx.GraphProto.output: object expected"); + message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); + } + } + if (object.valueInfo) { + if (!Array.isArray(object.valueInfo)) + throw TypeError(".onnx.GraphProto.valueInfo: array expected"); + message.valueInfo = []; + for (var i = 0; i < object.valueInfo.length; ++i) { + if (typeof object.valueInfo[i] !== "object") + throw TypeError(".onnx.GraphProto.valueInfo: object expected"); + message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); + } + } + if (object.quantizationAnnotation) { + if (!Array.isArray(object.quantizationAnnotation)) + throw TypeError(".onnx.GraphProto.quantizationAnnotation: array expected"); + message.quantizationAnnotation = []; + for (var i = 0; i < object.quantizationAnnotation.length; ++i) { + if (typeof object.quantizationAnnotation[i] !== "object") + throw TypeError(".onnx.GraphProto.quantizationAnnotation: object expected"); + message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a GraphProto 
message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.GraphProto + * @static + * @param {onnx.GraphProto} message GraphProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + GraphProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.node = []; + object.initializer = []; + object.input = []; + object.output = []; + object.valueInfo = []; + object.quantizationAnnotation = []; + object.sparseInitializer = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.initializer && message.initializer.length) { + object.initializer = []; + for (var j = 0; j < message.initializer.length; ++j) + object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); + } + if (message.valueInfo && message.valueInfo.length) { + object.valueInfo = []; + for (var j = 0; j < message.valueInfo.length; ++j) + object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); + } + if (message.quantizationAnnotation && message.quantizationAnnotation.length) { + object.quantizationAnnotation = []; + for (var j = 0; j < message.quantizationAnnotation.length; ++j) + object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject(message.quantizationAnnotation[j], options); + } + if (message.sparseInitializer && message.sparseInitializer.length) { + object.sparseInitializer = []; + for (var j = 0; j < message.sparseInitializer.length; ++j) + object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); + } + return object; + }; + + /** + * Converts this GraphProto to JSON. + * @function toJSON + * @memberof onnx.GraphProto + * @instance + * @returns {Object.} JSON object + */ + GraphProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for GraphProto + * @function getTypeUrl + * @memberof onnx.GraphProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.GraphProto"; + }; + + return GraphProto; + })(); + + onnx.TensorProto = (function() { + + /** + * Properties of a TensorProto. 
+ * @memberof onnx + * @interface ITensorProto + * @property {Array.|null} [dims] TensorProto dims + * @property {number|null} [dataType] TensorProto dataType + * @property {onnx.TensorProto.ISegment|null} [segment] TensorProto segment + * @property {Array.|null} [floatData] TensorProto floatData + * @property {Array.|null} [int32Data] TensorProto int32Data + * @property {Array.|null} [stringData] TensorProto stringData + * @property {Array.|null} [int64Data] TensorProto int64Data + * @property {string|null} [name] TensorProto name + * @property {string|null} [docString] TensorProto docString + * @property {Uint8Array|null} [rawData] TensorProto rawData + * @property {Array.|null} [externalData] TensorProto externalData + * @property {onnx.TensorProto.DataLocation|null} [dataLocation] TensorProto dataLocation + * @property {Array.|null} [doubleData] TensorProto doubleData + * @property {Array.|null} [uint64Data] TensorProto uint64Data + */ + + /** + * Constructs a new TensorProto. + * @memberof onnx + * @classdesc Represents a TensorProto. + * @implements ITensorProto + * @constructor + * @param {onnx.ITensorProto=} [properties] Properties to set + */ + function TensorProto(properties) { + this.dims = []; + this.floatData = []; + this.int32Data = []; + this.stringData = []; + this.int64Data = []; + this.externalData = []; + this.doubleData = []; + this.uint64Data = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorProto dims. + * @member {Array.} dims + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dims = $util.emptyArray; + + /** + * TensorProto dataType. + * @member {number} dataType + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataType = 0; + + /** + * TensorProto segment. + * @member {onnx.TensorProto.ISegment|null|undefined} segment + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.segment = null; + + /** + * TensorProto floatData. + * @member {Array.} floatData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.floatData = $util.emptyArray; + + /** + * TensorProto int32Data. + * @member {Array.} int32Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int32Data = $util.emptyArray; + + /** + * TensorProto stringData. + * @member {Array.} stringData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.stringData = $util.emptyArray; + + /** + * TensorProto int64Data. + * @member {Array.} int64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int64Data = $util.emptyArray; + + /** + * TensorProto name. + * @member {string} name + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.name = ""; + + /** + * TensorProto docString. + * @member {string} docString + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.docString = ""; + + /** + * TensorProto rawData. + * @member {Uint8Array} rawData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.rawData = $util.newBuffer([]); + + /** + * TensorProto externalData. + * @member {Array.} externalData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.externalData = $util.emptyArray; + + /** + * TensorProto dataLocation. 
+ * @member {onnx.TensorProto.DataLocation} dataLocation + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataLocation = 0; + + /** + * TensorProto doubleData. + * @member {Array.} doubleData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.doubleData = $util.emptyArray; + + /** + * TensorProto uint64Data. + * @member {Array.} uint64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.uint64Data = $util.emptyArray; + + /** + * Creates a new TensorProto instance using the specified properties. + * @function create + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto=} [properties] Properties to set + * @returns {onnx.TensorProto} TensorProto instance + */ + TensorProto.create = function create(properties) { + return new TensorProto(properties); + }; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 1, wireType 2 =*/10).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + if (message.dataType != null && Object.hasOwnProperty.call(message, "dataType")) + writer.uint32(/* id 2, wireType 0 =*/16).int32(message.dataType); + if (message.segment != null && Object.hasOwnProperty.call(message, "segment")) + $root.onnx.TensorProto.Segment.encode(message.segment, writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.floatData != null && message.floatData.length) { + writer.uint32(/* id 4, wireType 2 =*/34).fork(); + for (var i = 0; i < message.floatData.length; ++i) + writer.float(message.floatData[i]); + writer.ldelim(); + } + if (message.int32Data != null && message.int32Data.length) { + writer.uint32(/* id 5, wireType 2 =*/42).fork(); + for (var i = 0; i < message.int32Data.length; ++i) + writer.int32(message.int32Data[i]); + writer.ldelim(); + } + if (message.stringData != null && message.stringData.length) + for (var i = 0; i < message.stringData.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).bytes(message.stringData[i]); + if (message.int64Data != null && message.int64Data.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.int64Data.length; ++i) + writer.int64(message.int64Data[i]); + writer.ldelim(); + } + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.name); + if (message.rawData != null && Object.hasOwnProperty.call(message, "rawData")) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.rawData); + if (message.doubleData != null && message.doubleData.length) { + writer.uint32(/* id 10, wireType 2 =*/82).fork(); + for (var i = 0; i < message.doubleData.length; ++i) + writer.double(message.doubleData[i]); + writer.ldelim(); + } + if (message.uint64Data != null && message.uint64Data.length) { + writer.uint32(/* id 11, wireType 2 =*/90).fork(); + for (var i = 0; i < message.uint64Data.length; ++i) + writer.uint64(message.uint64Data[i]); + writer.ldelim(); + } + if (message.docString != null && 
Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 12, wireType 2 =*/98).string(message.docString); + if (message.externalData != null && message.externalData.length) + for (var i = 0; i < message.externalData.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.externalData[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.dataLocation != null && Object.hasOwnProperty.call(message, "dataLocation")) + writer.uint32(/* id 14, wireType 0 =*/112).int32(message.dataLocation); + return writer; + }; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + case 2: { + message.dataType = reader.int32(); + break; + } + case 3: { + message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); + break; + } + case 4: { + if (!(message.floatData && message.floatData.length)) + message.floatData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floatData.push(reader.float()); + } else + message.floatData.push(reader.float()); + break; + } + case 5: { + if (!(message.int32Data && message.int32Data.length)) + message.int32Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int32Data.push(reader.int32()); + } else + message.int32Data.push(reader.int32()); + break; + } + case 6: { + if (!(message.stringData && message.stringData.length)) + message.stringData = []; + message.stringData.push(reader.bytes()); + break; + } + case 7: { + if (!(message.int64Data && message.int64Data.length)) + message.int64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int64Data.push(reader.int64()); + } else + message.int64Data.push(reader.int64()); + break; + } + case 8: { + message.name = reader.string(); + break; + } + case 12: { + message.docString = reader.string(); + break; + } + case 9: { + message.rawData = reader.bytes(); + break; + } 
+ case 13: { + if (!(message.externalData && message.externalData.length)) + message.externalData = []; + message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 14: { + message.dataLocation = reader.int32(); + break; + } + case 10: { + if (!(message.doubleData && message.doubleData.length)) + message.doubleData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.doubleData.push(reader.double()); + } else + message.doubleData.push(reader.double()); + break; + } + case 11: { + if (!(message.uint64Data && message.uint64Data.length)) + message.uint64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.uint64Data.push(reader.uint64()); + } else + message.uint64Data.push(reader.uint64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorProto message. + * @function verify + * @memberof onnx.TensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + if (!$util.isInteger(message.dataType)) + return "dataType: integer expected"; + if (message.segment != null && message.hasOwnProperty("segment")) { + var error = $root.onnx.TensorProto.Segment.verify(message.segment); + if (error) + return "segment." 
+ error; + } + if (message.floatData != null && message.hasOwnProperty("floatData")) { + if (!Array.isArray(message.floatData)) + return "floatData: array expected"; + for (var i = 0; i < message.floatData.length; ++i) + if (typeof message.floatData[i] !== "number") + return "floatData: number[] expected"; + } + if (message.int32Data != null && message.hasOwnProperty("int32Data")) { + if (!Array.isArray(message.int32Data)) + return "int32Data: array expected"; + for (var i = 0; i < message.int32Data.length; ++i) + if (!$util.isInteger(message.int32Data[i])) + return "int32Data: integer[] expected"; + } + if (message.stringData != null && message.hasOwnProperty("stringData")) { + if (!Array.isArray(message.stringData)) + return "stringData: array expected"; + for (var i = 0; i < message.stringData.length; ++i) + if (!(message.stringData[i] && typeof message.stringData[i].length === "number" || $util.isString(message.stringData[i]))) + return "stringData: buffer[] expected"; + } + if (message.int64Data != null && message.hasOwnProperty("int64Data")) { + if (!Array.isArray(message.int64Data)) + return "int64Data: array expected"; + for (var i = 0; i < message.int64Data.length; ++i) + if (!$util.isInteger(message.int64Data[i]) && !(message.int64Data[i] && $util.isInteger(message.int64Data[i].low) && $util.isInteger(message.int64Data[i].high))) + return "int64Data: integer|Long[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.rawData != null && message.hasOwnProperty("rawData")) + if (!(message.rawData && typeof message.rawData.length === "number" || $util.isString(message.rawData))) + return "rawData: buffer expected"; + if (message.externalData != null && message.hasOwnProperty("externalData")) { + if (!Array.isArray(message.externalData)) + return "externalData: array expected"; + for (var i = 0; i < message.externalData.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); + if (error) + return "externalData." + error; + } + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + switch (message.dataLocation) { + default: + return "dataLocation: enum value expected"; + case 0: + case 1: + break; + } + if (message.doubleData != null && message.hasOwnProperty("doubleData")) { + if (!Array.isArray(message.doubleData)) + return "doubleData: array expected"; + for (var i = 0; i < message.doubleData.length; ++i) + if (typeof message.doubleData[i] !== "number") + return "doubleData: number[] expected"; + } + if (message.uint64Data != null && message.hasOwnProperty("uint64Data")) { + if (!Array.isArray(message.uint64Data)) + return "uint64Data: array expected"; + for (var i = 0; i < message.uint64Data.length; ++i) + if (!$util.isInteger(message.uint64Data[i]) && !(message.uint64Data[i] && $util.isInteger(message.uint64Data[i].low) && $util.isInteger(message.uint64Data[i].high))) + return "uint64Data: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto} TensorProto + */ + TensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto) + return object; + var message = new $root.onnx.TensorProto(); + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.TensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + if (object.dataType != null) + message.dataType = object.dataType | 0; + if (object.segment != null) { + if (typeof object.segment !== "object") + throw TypeError(".onnx.TensorProto.segment: object expected"); + message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); + } + if (object.floatData) { + if (!Array.isArray(object.floatData)) + throw TypeError(".onnx.TensorProto.floatData: array expected"); + message.floatData = []; + for (var i = 0; i < object.floatData.length; ++i) + message.floatData[i] = Number(object.floatData[i]); + } + if (object.int32Data) { + if (!Array.isArray(object.int32Data)) + throw TypeError(".onnx.TensorProto.int32Data: array expected"); + message.int32Data = []; + for (var i = 0; i < object.int32Data.length; ++i) + message.int32Data[i] = object.int32Data[i] | 0; + } + if (object.stringData) { + if (!Array.isArray(object.stringData)) + throw TypeError(".onnx.TensorProto.stringData: array expected"); + message.stringData = []; + for (var i = 0; i < object.stringData.length; ++i) + if (typeof object.stringData[i] === "string") + $util.base64.decode(object.stringData[i], message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i])), 0); + else if (object.stringData[i].length >= 0) + message.stringData[i] = object.stringData[i]; + } + if (object.int64Data) { + if (!Array.isArray(object.int64Data)) + throw TypeError(".onnx.TensorProto.int64Data: array expected"); + message.int64Data = []; + for (var i = 0; i < object.int64Data.length; ++i) + if ($util.Long) + (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; + else if (typeof object.int64Data[i] === "string") + message.int64Data[i] = parseInt(object.int64Data[i], 10); + else if (typeof object.int64Data[i] === "number") + message.int64Data[i] = object.int64Data[i]; + else if (typeof object.int64Data[i] === "object") + message.int64Data[i] = new $util.LongBits(object.int64Data[i].low >>> 0, object.int64Data[i].high >>> 0).toNumber(); + } + if (object.name != null) + message.name = String(object.name); + if (object.docString != null) + message.docString = String(object.docString); + if (object.rawData != null) + if (typeof object.rawData === "string") + $util.base64.decode(object.rawData, message.rawData = $util.newBuffer($util.base64.length(object.rawData)), 0); + else if (object.rawData.length >= 0) + message.rawData = object.rawData; + if (object.externalData) { + if (!Array.isArray(object.externalData)) + throw TypeError(".onnx.TensorProto.externalData: array expected"); + message.externalData = []; + for (var i = 0; 
i < object.externalData.length; ++i) { + if (typeof object.externalData[i] !== "object") + throw TypeError(".onnx.TensorProto.externalData: object expected"); + message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); + } + } + switch (object.dataLocation) { + default: + if (typeof object.dataLocation === "number") { + message.dataLocation = object.dataLocation; + break; + } + break; + case "DEFAULT": + case 0: + message.dataLocation = 0; + break; + case "EXTERNAL": + case 1: + message.dataLocation = 1; + break; + } + if (object.doubleData) { + if (!Array.isArray(object.doubleData)) + throw TypeError(".onnx.TensorProto.doubleData: array expected"); + message.doubleData = []; + for (var i = 0; i < object.doubleData.length; ++i) + message.doubleData[i] = Number(object.doubleData[i]); + } + if (object.uint64Data) { + if (!Array.isArray(object.uint64Data)) + throw TypeError(".onnx.TensorProto.uint64Data: array expected"); + message.uint64Data = []; + for (var i = 0; i < object.uint64Data.length; ++i) + if ($util.Long) + (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; + else if (typeof object.uint64Data[i] === "string") + message.uint64Data[i] = parseInt(object.uint64Data[i], 10); + else if (typeof object.uint64Data[i] === "number") + message.uint64Data[i] = object.uint64Data[i]; + else if (typeof object.uint64Data[i] === "object") + message.uint64Data[i] = new $util.LongBits(object.uint64Data[i].low >>> 0, object.uint64Data[i].high >>> 0).toNumber(true); + } + return message; + }; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto + * @static + * @param {onnx.TensorProto} message TensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.dims = []; + object.floatData = []; + object.int32Data = []; + object.stringData = []; + object.int64Data = []; + object.doubleData = []; + object.uint64Data = []; + object.externalData = []; + } + if (options.defaults) { + object.dataType = 0; + object.segment = null; + object.name = ""; + if (options.bytes === String) + object.rawData = ""; + else { + object.rawData = []; + if (options.bytes !== Array) + object.rawData = $util.newBuffer(object.rawData); + } + object.docString = ""; + object.dataLocation = options.enums === String ? "DEFAULT" : 0; + } + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? 
new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + object.dataType = message.dataType; + if (message.segment != null && message.hasOwnProperty("segment")) + object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); + if (message.floatData && message.floatData.length) { + object.floatData = []; + for (var j = 0; j < message.floatData.length; ++j) + object.floatData[j] = options.json && !isFinite(message.floatData[j]) ? String(message.floatData[j]) : message.floatData[j]; + } + if (message.int32Data && message.int32Data.length) { + object.int32Data = []; + for (var j = 0; j < message.int32Data.length; ++j) + object.int32Data[j] = message.int32Data[j]; + } + if (message.stringData && message.stringData.length) { + object.stringData = []; + for (var j = 0; j < message.stringData.length; ++j) + object.stringData[j] = options.bytes === String ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.stringData[j]) : message.stringData[j]; + } + if (message.int64Data && message.int64Data.length) { + object.int64Data = []; + for (var j = 0; j < message.int64Data.length; ++j) + if (typeof message.int64Data[j] === "number") + object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; + else + object.int64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.int64Data[j]) : options.longs === Number ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() : message.int64Data[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.rawData != null && message.hasOwnProperty("rawData")) + object.rawData = options.bytes === String ? $util.base64.encode(message.rawData, 0, message.rawData.length) : options.bytes === Array ? Array.prototype.slice.call(message.rawData) : message.rawData; + if (message.doubleData && message.doubleData.length) { + object.doubleData = []; + for (var j = 0; j < message.doubleData.length; ++j) + object.doubleData[j] = options.json && !isFinite(message.doubleData[j]) ? String(message.doubleData[j]) : message.doubleData[j]; + } + if (message.uint64Data && message.uint64Data.length) { + object.uint64Data = []; + for (var j = 0; j < message.uint64Data.length; ++j) + if (typeof message.uint64Data[j] === "number") + object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; + else + object.uint64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.uint64Data[j]) : options.longs === Number ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) : message.uint64Data[j]; + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.externalData && message.externalData.length) { + object.externalData = []; + for (var j = 0; j < message.externalData.length; ++j) + object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + object.dataLocation = options.enums === String ? $root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined ? 
message.dataLocation : $root.onnx.TensorProto.DataLocation[message.dataLocation] : message.dataLocation; + return object; + }; + + /** + * Converts this TensorProto to JSON. + * @function toJSON + * @memberof onnx.TensorProto + * @instance + * @returns {Object.} JSON object + */ + TensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorProto + * @function getTypeUrl + * @memberof onnx.TensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto"; + }; + + /** + * DataType enum. + * @name onnx.TensorProto.DataType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} UINT8=2 UINT8 value + * @property {number} INT8=3 INT8 value + * @property {number} UINT16=4 UINT16 value + * @property {number} INT16=5 INT16 value + * @property {number} INT32=6 INT32 value + * @property {number} INT64=7 INT64 value + * @property {number} STRING=8 STRING value + * @property {number} BOOL=9 BOOL value + * @property {number} FLOAT16=10 FLOAT16 value + * @property {number} DOUBLE=11 DOUBLE value + * @property {number} UINT32=12 UINT32 value + * @property {number} UINT64=13 UINT64 value + * @property {number} COMPLEX64=14 COMPLEX64 value + * @property {number} COMPLEX128=15 COMPLEX128 value + * @property {number} BFLOAT16=16 BFLOAT16 value + * @property {number} FLOAT8E4M3FN=17 FLOAT8E4M3FN value + * @property {number} FLOAT8E4M3FNUZ=18 FLOAT8E4M3FNUZ value + * @property {number} FLOAT8E5M2=19 FLOAT8E5M2 value + * @property {number} FLOAT8E5M2FNUZ=20 FLOAT8E5M2FNUZ value + */ + TensorProto.DataType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "UINT8"] = 2; + values[valuesById[3] = "INT8"] = 3; + values[valuesById[4] = "UINT16"] = 4; + values[valuesById[5] = "INT16"] = 5; + values[valuesById[6] = "INT32"] = 6; + values[valuesById[7] = "INT64"] = 7; + values[valuesById[8] = "STRING"] = 8; + values[valuesById[9] = "BOOL"] = 9; + values[valuesById[10] = "FLOAT16"] = 10; + values[valuesById[11] = "DOUBLE"] = 11; + values[valuesById[12] = "UINT32"] = 12; + values[valuesById[13] = "UINT64"] = 13; + values[valuesById[14] = "COMPLEX64"] = 14; + values[valuesById[15] = "COMPLEX128"] = 15; + values[valuesById[16] = "BFLOAT16"] = 16; + values[valuesById[17] = "FLOAT8E4M3FN"] = 17; + values[valuesById[18] = "FLOAT8E4M3FNUZ"] = 18; + values[valuesById[19] = "FLOAT8E5M2"] = 19; + values[valuesById[20] = "FLOAT8E5M2FNUZ"] = 20; + return values; + })(); + + TensorProto.Segment = (function() { + + /** + * Properties of a Segment. + * @memberof onnx.TensorProto + * @interface ISegment + * @property {number|Long|null} [begin] Segment begin + * @property {number|Long|null} [end] Segment end + */ + + /** + * Constructs a new Segment. + * @memberof onnx.TensorProto + * @classdesc Represents a Segment. 
+ * @implements ISegment + * @constructor + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + */ + function Segment(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Segment begin. + * @member {number|Long} begin + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Segment end. + * @member {number|Long} end + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.end = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new Segment instance using the specified properties. + * @function create + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + * @returns {onnx.TensorProto.Segment} Segment instance + */ + Segment.create = function create(properties) { + return new Segment(properties); + }; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.begin != null && Object.hasOwnProperty.call(message, "begin")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.begin); + if (message.end != null && Object.hasOwnProperty.call(message, "end")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.end); + return writer; + }; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Segment message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto.Segment(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.begin = reader.int64(); + break; + } + case 2: { + message.end = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Segment message. + * @function verify + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Segment.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.begin != null && message.hasOwnProperty("begin")) + if (!$util.isInteger(message.begin) && !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high))) + return "begin: integer|Long expected"; + if (message.end != null && message.hasOwnProperty("end")) + if (!$util.isInteger(message.end) && !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high))) + return "end: integer|Long expected"; + return null; + }; + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorProto.Segment} Segment + */ + Segment.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto.Segment) + return object; + var message = new $root.onnx.TensorProto.Segment(); + if (object.begin != null) + if ($util.Long) + (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; + else if (typeof object.begin === "string") + message.begin = parseInt(object.begin, 10); + else if (typeof object.begin === "number") + message.begin = object.begin; + else if (typeof object.begin === "object") + message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); + if (object.end != null) + if ($util.Long) + (message.end = $util.Long.fromValue(object.end)).unsigned = false; + else if (typeof object.end === "string") + message.end = parseInt(object.end, 10); + else if (typeof object.end === "number") + message.end = object.end; + else if (typeof object.end === "object") + message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.Segment} message Segment + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Segment.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.begin = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.begin = options.longs === String ? "0" : 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.end = options.longs === String ? 
long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.end = options.longs === String ? "0" : 0; + } + if (message.begin != null && message.hasOwnProperty("begin")) + if (typeof message.begin === "number") + object.begin = options.longs === String ? String(message.begin) : message.begin; + else + object.begin = options.longs === String ? $util.Long.prototype.toString.call(message.begin) : options.longs === Number ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() : message.begin; + if (message.end != null && message.hasOwnProperty("end")) + if (typeof message.end === "number") + object.end = options.longs === String ? String(message.end) : message.end; + else + object.end = options.longs === String ? $util.Long.prototype.toString.call(message.end) : options.longs === Number ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() : message.end; + return object; + }; + + /** + * Converts this Segment to JSON. + * @function toJSON + * @memberof onnx.TensorProto.Segment + * @instance + * @returns {Object.} JSON object + */ + Segment.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Segment + * @function getTypeUrl + * @memberof onnx.TensorProto.Segment + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto.Segment"; + }; + + return Segment; + })(); + + /** + * DataLocation enum. + * @name onnx.TensorProto.DataLocation + * @enum {number} + * @property {number} DEFAULT=0 DEFAULT value + * @property {number} EXTERNAL=1 EXTERNAL value + */ + TensorProto.DataLocation = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "DEFAULT"] = 0; + values[valuesById[1] = "EXTERNAL"] = 1; + return values; + })(); + + return TensorProto; + })(); + + onnx.SparseTensorProto = (function() { + + /** + * Properties of a SparseTensorProto. + * @memberof onnx + * @interface ISparseTensorProto + * @property {onnx.ITensorProto|null} [values] SparseTensorProto values + * @property {onnx.ITensorProto|null} [indices] SparseTensorProto indices + * @property {Array.|null} [dims] SparseTensorProto dims + */ + + /** + * Constructs a new SparseTensorProto. + * @memberof onnx + * @classdesc Represents a SparseTensorProto. + * @implements ISparseTensorProto + * @constructor + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + */ + function SparseTensorProto(properties) { + this.dims = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensorProto values. + * @member {onnx.ITensorProto|null|undefined} values + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.values = null; + + /** + * SparseTensorProto indices. + * @member {onnx.ITensorProto|null|undefined} indices + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.indices = null; + + /** + * SparseTensorProto dims. 
+ * @member {Array.} dims + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.dims = $util.emptyArray; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @function create + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + * @returns {onnx.SparseTensorProto} SparseTensorProto instance + */ + SparseTensorProto.create = function create(properties) { + return new SparseTensorProto(properties); + }; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.values != null && Object.hasOwnProperty.call(message, "values")) + $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.indices != null && Object.hasOwnProperty.call(message, "indices")) + $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 3, wireType 2 =*/26).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + return writer; + }; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.SparseTensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensorProto message. + * @function verify + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.values != null && message.hasOwnProperty("values")) { + var error = $root.onnx.TensorProto.verify(message.values); + if (error) + return "values." + error; + } + if (message.indices != null && message.hasOwnProperty("indices")) { + var error = $root.onnx.TensorProto.verify(message.indices); + if (error) + return "indices." + error; + } + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.SparseTensorProto} SparseTensorProto + */ + SparseTensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.SparseTensorProto) + return object; + var message = new $root.onnx.SparseTensorProto(); + if (object.values != null) { + if (typeof object.values !== "object") + throw TypeError(".onnx.SparseTensorProto.values: object expected"); + message.values = $root.onnx.TensorProto.fromObject(object.values); + } + if (object.indices != null) { + if (typeof object.indices !== "object") + throw TypeError(".onnx.SparseTensorProto.indices: object expected"); + message.indices = $root.onnx.TensorProto.fromObject(object.indices); + } + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.SparseTensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.SparseTensorProto} message SparseTensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dims = []; + if (options.defaults) { + object.values = null; + object.indices = null; + } + if (message.values != null && message.hasOwnProperty("values")) + object.values = $root.onnx.TensorProto.toObject(message.values, options); + if (message.indices != null && message.hasOwnProperty("indices")) + object.indices = $root.onnx.TensorProto.toObject(message.indices, options); + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + return object; + }; + + /** + * Converts this SparseTensorProto to JSON. 
+ * @function toJSON + * @memberof onnx.SparseTensorProto + * @instance + * @returns {Object.} JSON object + */ + SparseTensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensorProto + * @function getTypeUrl + * @memberof onnx.SparseTensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.SparseTensorProto"; + }; + + return SparseTensorProto; + })(); + + onnx.TensorShapeProto = (function() { + + /** + * Properties of a TensorShapeProto. + * @memberof onnx + * @interface ITensorShapeProto + * @property {Array.|null} [dim] TensorShapeProto dim + */ + + /** + * Constructs a new TensorShapeProto. + * @memberof onnx + * @classdesc Represents a TensorShapeProto. + * @implements ITensorShapeProto + * @constructor + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + */ + function TensorShapeProto(properties) { + this.dim = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorShapeProto dim. + * @member {Array.} dim + * @memberof onnx.TensorShapeProto + * @instance + */ + TensorShapeProto.prototype.dim = $util.emptyArray; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + * @returns {onnx.TensorShapeProto} TensorShapeProto instance + */ + TensorShapeProto.create = function create(properties) { + return new TensorShapeProto(properties); + }; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dim != null && message.dim.length) + for (var i = 0; i < message.dim.length; ++i) + $root.onnx.TensorShapeProto.Dimension.encode(message.dim[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dim && message.dim.length)) + message.dim = []; + message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorShapeProto message. + * @function verify + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorShapeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dim != null && message.hasOwnProperty("dim")) { + if (!Array.isArray(message.dim)) + return "dim: array expected"; + for (var i = 0; i < message.dim.length; ++i) { + var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); + if (error) + return "dim." + error; + } + } + return null; + }; + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto} TensorShapeProto + */ + TensorShapeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto) + return object; + var message = new $root.onnx.TensorShapeProto(); + if (object.dim) { + if (!Array.isArray(object.dim)) + throw TypeError(".onnx.TensorShapeProto.dim: array expected"); + message.dim = []; + for (var i = 0; i < object.dim.length; ++i) { + if (typeof object.dim[i] !== "object") + throw TypeError(".onnx.TensorShapeProto.dim: object expected"); + message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.TensorShapeProto} message TensorShapeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorShapeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dim = []; + if (message.dim && message.dim.length) { + object.dim = []; + for (var j = 0; j < message.dim.length; ++j) + object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); + } + return object; + }; + + /** + * Converts this TensorShapeProto to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto + * @instance + * @returns {Object.} JSON object + */ + TensorShapeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorShapeProto + * @function getTypeUrl + * @memberof onnx.TensorShapeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto"; + }; + + TensorShapeProto.Dimension = (function() { + + /** + * Properties of a Dimension. + * @memberof onnx.TensorShapeProto + * @interface IDimension + * @property {number|Long|null} [dimValue] Dimension dimValue + * @property {string|null} [dimParam] Dimension dimParam + * @property {string|null} [denotation] Dimension denotation + */ + + /** + * Constructs a new Dimension. + * @memberof onnx.TensorShapeProto + * @classdesc Represents a Dimension. + * @implements IDimension + * @constructor + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + */ + function Dimension(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Dimension dimValue. + * @member {number|Long|null|undefined} dimValue + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimValue = null; + + /** + * Dimension dimParam. + * @member {string|null|undefined} dimParam + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimParam = null; + + /** + * Dimension denotation. + * @member {string} denotation + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * Dimension value. + * @member {"dimValue"|"dimParam"|undefined} value + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Object.defineProperty(Dimension.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["dimValue", "dimParam"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new Dimension instance using the specified properties. 
+ * @function create + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + * @returns {onnx.TensorShapeProto.Dimension} Dimension instance + */ + Dimension.create = function create(properties) { + return new Dimension(properties); + }; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dimValue != null && Object.hasOwnProperty.call(message, "dimValue")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.dimValue); + if (message.dimParam != null && Object.hasOwnProperty.call(message, "dimParam")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.dimParam); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.denotation); + return writer; + }; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto.Dimension(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.dimValue = reader.int64(); + break; + } + case 2: { + message.dimParam = reader.string(); + break; + } + case 3: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Dimension message. + * @function verify + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Dimension.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + properties.value = 1; + if (!$util.isInteger(message.dimValue) && !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high))) + return "dimValue: integer|Long expected"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + if (!$util.isString(message.dimParam)) + return "dimParam: string expected"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto.Dimension} Dimension + */ + Dimension.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto.Dimension) + return object; + var message = new $root.onnx.TensorShapeProto.Dimension(); + if (object.dimValue != null) + if ($util.Long) + (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; + else if (typeof object.dimValue === "string") + message.dimValue = parseInt(object.dimValue, 10); + else if (typeof object.dimValue === "number") + message.dimValue = object.dimValue; + else if (typeof object.dimValue === "object") + message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); + if (object.dimParam != null) + message.dimParam = String(object.dimParam); + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.Dimension} message Dimension + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Dimension.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + if (typeof message.dimValue === "number") + object.dimValue = options.longs === String ? 
String(message.dimValue) : message.dimValue; + else + object.dimValue = options.longs === String ? $util.Long.prototype.toString.call(message.dimValue) : options.longs === Number ? new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() : message.dimValue; + if (options.oneofs) + object.value = "dimValue"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + object.dimParam = message.dimParam; + if (options.oneofs) + object.value = "dimParam"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + return object; + }; + + /** + * Converts this Dimension to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto.Dimension + * @instance + * @returns {Object.} JSON object + */ + Dimension.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Dimension + * @function getTypeUrl + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto.Dimension"; + }; + + return Dimension; + })(); + + return TensorShapeProto; + })(); + + onnx.TypeProto = (function() { + + /** + * Properties of a TypeProto. + * @memberof onnx + * @interface ITypeProto + * @property {onnx.TypeProto.ITensor|null} [tensorType] TypeProto tensorType + * @property {onnx.TypeProto.ISequence|null} [sequenceType] TypeProto sequenceType + * @property {onnx.TypeProto.IMap|null} [mapType] TypeProto mapType + * @property {onnx.TypeProto.IOptional|null} [optionalType] TypeProto optionalType + * @property {onnx.TypeProto.ISparseTensor|null} [sparseTensorType] TypeProto sparseTensorType + * @property {string|null} [denotation] TypeProto denotation + */ + + /** + * Constructs a new TypeProto. + * @memberof onnx + * @classdesc Represents a TypeProto. + * @implements ITypeProto + * @constructor + * @param {onnx.ITypeProto=} [properties] Properties to set + */ + function TypeProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TypeProto tensorType. + * @member {onnx.TypeProto.ITensor|null|undefined} tensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.tensorType = null; + + /** + * TypeProto sequenceType. + * @member {onnx.TypeProto.ISequence|null|undefined} sequenceType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sequenceType = null; + + /** + * TypeProto mapType. + * @member {onnx.TypeProto.IMap|null|undefined} mapType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.mapType = null; + + /** + * TypeProto optionalType. + * @member {onnx.TypeProto.IOptional|null|undefined} optionalType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.optionalType = null; + + /** + * TypeProto sparseTensorType. + * @member {onnx.TypeProto.ISparseTensor|null|undefined} sparseTensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sparseTensorType = null; + + /** + * TypeProto denotation. 
+ * @member {string} denotation + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * TypeProto value. + * @member {"tensorType"|"sequenceType"|"mapType"|"optionalType"|"sparseTensorType"|undefined} value + * @memberof onnx.TypeProto + * @instance + */ + Object.defineProperty(TypeProto.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["tensorType", "sequenceType", "mapType", "optionalType", "sparseTensorType"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new TypeProto instance using the specified properties. + * @function create + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto=} [properties] Properties to set + * @returns {onnx.TypeProto} TypeProto instance + */ + TypeProto.create = function create(properties) { + return new TypeProto(properties); + }; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorType != null && Object.hasOwnProperty.call(message, "tensorType")) + $root.onnx.TypeProto.Tensor.encode(message.tensorType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.sequenceType != null && Object.hasOwnProperty.call(message, "sequenceType")) + $root.onnx.TypeProto.Sequence.encode(message.sequenceType, writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + if (message.mapType != null && Object.hasOwnProperty.call(message, "mapType")) + $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.denotation); + if (message.sparseTensorType != null && Object.hasOwnProperty.call(message, "sparseTensorType")) + $root.onnx.TypeProto.SparseTensor.encode(message.sparseTensorType, writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.optionalType != null && Object.hasOwnProperty.call(message, "optionalType")) + $root.onnx.TypeProto.Optional.encode(message.optionalType, writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); + break; + } + case 4: { + message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); + break; + } + case 5: { + message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); + break; + } + case 9: { + message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); + break; + } + case 8: { + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); + break; + } + case 6: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TypeProto message. + * @function verify + * @memberof onnx.TypeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TypeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + properties.value = 1; + { + var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); + if (error) + return "tensorType." + error; + } + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); + if (error) + return "sequenceType." + error; + } + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Map.verify(message.mapType); + if (error) + return "mapType." + error; + } + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); + if (error) + return "optionalType." 
+ error; + } + } + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType); + if (error) + return "sparseTensorType." + error; + } + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto} TypeProto + */ + TypeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto) + return object; + var message = new $root.onnx.TypeProto(); + if (object.tensorType != null) { + if (typeof object.tensorType !== "object") + throw TypeError(".onnx.TypeProto.tensorType: object expected"); + message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType); + } + if (object.sequenceType != null) { + if (typeof object.sequenceType !== "object") + throw TypeError(".onnx.TypeProto.sequenceType: object expected"); + message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType); + } + if (object.mapType != null) { + if (typeof object.mapType !== "object") + throw TypeError(".onnx.TypeProto.mapType: object expected"); + message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType); + } + if (object.optionalType != null) { + if (typeof object.optionalType !== "object") + throw TypeError(".onnx.TypeProto.optionalType: object expected"); + message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType); + } + if (object.sparseTensorType != null) { + if (typeof object.sparseTensorType !== "object") + throw TypeError(".onnx.TypeProto.sparseTensorType: object expected"); + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType); + } + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto + * @static + * @param {onnx.TypeProto} message TypeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TypeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); + if (options.oneofs) + object.value = "tensorType"; + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); + if (options.oneofs) + object.value = "sequenceType"; + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); + if (options.oneofs) + object.value = "mapType"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); + if (options.oneofs) + object.value = "sparseTensorType"; + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); + if (options.oneofs) + object.value = "optionalType"; + } + return object; + }; + + /** + * Converts this TypeProto to JSON. + * @function toJSON + * @memberof onnx.TypeProto + * @instance + * @returns {Object.} JSON object + */ + TypeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TypeProto + * @function getTypeUrl + * @memberof onnx.TypeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto"; + }; + + TypeProto.Tensor = (function() { + + /** + * Properties of a Tensor. + * @memberof onnx.TypeProto + * @interface ITensor + * @property {number|null} [elemType] Tensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] Tensor shape + */ + + /** + * Constructs a new Tensor. + * @memberof onnx.TypeProto + * @classdesc Represents a Tensor. + * @implements ITensor + * @constructor + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + */ + function Tensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Tensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.elemType = 0; + + /** + * Tensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.shape = null; + + /** + * Creates a new Tensor instance using the specified properties. 
+ * @function create + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + * @returns {onnx.TypeProto.Tensor} Tensor instance + */ + Tensor.create = function create(properties) { + return new Tensor(properties); + }; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Tensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Tensor message. 
+ * @function verify + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Tensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Tensor} Tensor + */ + Tensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Tensor) + return object; + var message = new $root.onnx.TypeProto.Tensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.Tensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.Tensor} message Tensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Tensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this Tensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Tensor + * @instance + * @returns {Object.} JSON object + */ + Tensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Tensor + * @function getTypeUrl + * @memberof onnx.TypeProto.Tensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Tensor"; + }; + + return Tensor; + })(); + + TypeProto.Sequence = (function() { + + /** + * Properties of a Sequence. + * @memberof onnx.TypeProto + * @interface ISequence + * @property {onnx.ITypeProto|null} [elemType] Sequence elemType + */ + + /** + * Constructs a new Sequence. + * @memberof onnx.TypeProto + * @classdesc Represents a Sequence. 
+ * @implements ISequence + * @constructor + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + */ + function Sequence(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Sequence elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Sequence + * @instance + */ + Sequence.prototype.elemType = null; + + /** + * Creates a new Sequence instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + * @returns {onnx.TypeProto.Sequence} Sequence instance + */ + Sequence.create = function create(properties) { + return new Sequence(properties); + }; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Sequence(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Sequence message. + * @function verify + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Sequence.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Sequence} Sequence + */ + Sequence.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Sequence) + return object; + var message = new $root.onnx.TypeProto.Sequence(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Sequence.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.Sequence} message Sequence + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Sequence.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Sequence to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Sequence + * @instance + * @returns {Object.} JSON object + */ + Sequence.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Sequence + * @function getTypeUrl + * @memberof onnx.TypeProto.Sequence + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Sequence"; + }; + + return Sequence; + })(); + + TypeProto.Map = (function() { + + /** + * Properties of a Map. 
+ * @memberof onnx.TypeProto + * @interface IMap + * @property {number|null} [keyType] Map keyType + * @property {onnx.ITypeProto|null} [valueType] Map valueType + */ + + /** + * Constructs a new Map. + * @memberof onnx.TypeProto + * @classdesc Represents a Map. + * @implements IMap + * @constructor + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + */ + function Map(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Map keyType. + * @member {number} keyType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.keyType = 0; + + /** + * Map valueType. + * @member {onnx.ITypeProto|null|undefined} valueType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.valueType = null; + + /** + * Creates a new Map instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + * @returns {onnx.TypeProto.Map} Map instance + */ + Map.create = function create(properties) { + return new Map(properties); + }; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.keyType != null && Object.hasOwnProperty.call(message, "keyType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.keyType); + if (message.valueType != null && Object.hasOwnProperty.call(message, "valueType")) + $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Map message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Map(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.keyType = reader.int32(); + break; + } + case 2: { + message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Map message. + * @function verify + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Map.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.keyType != null && message.hasOwnProperty("keyType")) + if (!$util.isInteger(message.keyType)) + return "keyType: integer expected"; + if (message.valueType != null && message.hasOwnProperty("valueType")) { + var error = $root.onnx.TypeProto.verify(message.valueType); + if (error) + return "valueType." + error; + } + return null; + }; + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Map} Map + */ + Map.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Map) + return object; + var message = new $root.onnx.TypeProto.Map(); + if (object.keyType != null) + message.keyType = object.keyType | 0; + if (object.valueType != null) { + if (typeof object.valueType !== "object") + throw TypeError(".onnx.TypeProto.Map.valueType: object expected"); + message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); + } + return message; + }; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.Map} message Map + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Map.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.keyType = 0; + object.valueType = null; + } + if (message.keyType != null && message.hasOwnProperty("keyType")) + object.keyType = message.keyType; + if (message.valueType != null && message.hasOwnProperty("valueType")) + object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); + return object; + }; + + /** + * Converts this Map to JSON. 
+ * @function toJSON + * @memberof onnx.TypeProto.Map + * @instance + * @returns {Object.} JSON object + */ + Map.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Map + * @function getTypeUrl + * @memberof onnx.TypeProto.Map + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Map"; + }; + + return Map; + })(); + + TypeProto.Optional = (function() { + + /** + * Properties of an Optional. + * @memberof onnx.TypeProto + * @interface IOptional + * @property {onnx.ITypeProto|null} [elemType] Optional elemType + */ + + /** + * Constructs a new Optional. + * @memberof onnx.TypeProto + * @classdesc Represents an Optional. + * @implements IOptional + * @constructor + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + */ + function Optional(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Optional elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Optional + * @instance + */ + Optional.prototype.elemType = null; + + /** + * Creates a new Optional instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + * @returns {onnx.TypeProto.Optional} Optional instance + */ + Optional.create = function create(properties) { + return new Optional(properties); + }; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an Optional message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Optional(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an Optional message. + * @function verify + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Optional.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Optional} Optional + */ + Optional.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Optional) + return object; + var message = new $root.onnx.TypeProto.Optional(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Optional.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.Optional} message Optional + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Optional.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Optional to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Optional + * @instance + * @returns {Object.} JSON object + */ + Optional.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Optional + * @function getTypeUrl + * @memberof onnx.TypeProto.Optional + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Optional"; + }; + + return Optional; + })(); + + TypeProto.SparseTensor = (function() { + + /** + * Properties of a SparseTensor. + * @memberof onnx.TypeProto + * @interface ISparseTensor + * @property {number|null} [elemType] SparseTensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] SparseTensor shape + */ + + /** + * Constructs a new SparseTensor. + * @memberof onnx.TypeProto + * @classdesc Represents a SparseTensor. + * @implements ISparseTensor + * @constructor + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + */ + function SparseTensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.elemType = 0; + + /** + * SparseTensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.shape = null; + + /** + * Creates a new SparseTensor instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + * @returns {onnx.TypeProto.SparseTensor} SparseTensor instance + */ + SparseTensor.create = function create(properties) { + return new SparseTensor(properties); + }; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. 
+ * @function encode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.SparseTensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensor message. 
+ * @function verify + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + */ + SparseTensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.SparseTensor) + return object; + var message = new $root.onnx.TypeProto.SparseTensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.SparseTensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.SparseTensor} message SparseTensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this SparseTensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.SparseTensor + * @instance + * @returns {Object.} JSON object + */ + SparseTensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensor + * @function getTypeUrl + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.SparseTensor"; + }; + + return SparseTensor; + })(); + + return TypeProto; + })(); + + onnx.OperatorSetIdProto = (function() { + + /** + * Properties of an OperatorSetIdProto. + * @memberof onnx + * @interface IOperatorSetIdProto + * @property {string|null} [domain] OperatorSetIdProto domain + * @property {number|Long|null} [version] OperatorSetIdProto version + */ + + /** + * Constructs a new OperatorSetIdProto. 
+ * @memberof onnx + * @classdesc Represents an OperatorSetIdProto. + * @implements IOperatorSetIdProto + * @constructor + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + */ + function OperatorSetIdProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * OperatorSetIdProto domain. + * @member {string} domain + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.domain = ""; + + /** + * OperatorSetIdProto version. + * @member {number|Long} version + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @function create + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto instance + */ + OperatorSetIdProto.create = function create(properties) { + return new OperatorSetIdProto(properties); + }; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.domain); + if (message.version != null && Object.hasOwnProperty.call(message, "version")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.version); + return writer; + }; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.OperatorSetIdProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.domain = reader.string(); + break; + } + case 2: { + message.version = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an OperatorSetIdProto message. + * @function verify + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + OperatorSetIdProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.version != null && message.hasOwnProperty("version")) + if (!$util.isInteger(message.version) && !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high))) + return "version: integer|Long expected"; + return null; + }; + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + */ + OperatorSetIdProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.OperatorSetIdProto) + return object; + var message = new $root.onnx.OperatorSetIdProto(); + if (object.domain != null) + message.domain = String(object.domain); + if (object.version != null) + if ($util.Long) + (message.version = $util.Long.fromValue(object.version)).unsigned = false; + else if (typeof object.version === "string") + message.version = parseInt(object.version, 10); + else if (typeof object.version === "number") + message.version = object.version; + else if (typeof object.version === "object") + message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.OperatorSetIdProto} message OperatorSetIdProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + OperatorSetIdProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.version = options.longs === String ? 
long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.version = options.longs === String ? "0" : 0; + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.version != null && message.hasOwnProperty("version")) + if (typeof message.version === "number") + object.version = options.longs === String ? String(message.version) : message.version; + else + object.version = options.longs === String ? $util.Long.prototype.toString.call(message.version) : options.longs === Number ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() : message.version; + return object; + }; + + /** + * Converts this OperatorSetIdProto to JSON. + * @function toJSON + * @memberof onnx.OperatorSetIdProto + * @instance + * @returns {Object.} JSON object + */ + OperatorSetIdProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for OperatorSetIdProto + * @function getTypeUrl + * @memberof onnx.OperatorSetIdProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.OperatorSetIdProto"; + }; + + return OperatorSetIdProto; + })(); + + /** + * OperatorStatus enum. + * @name onnx.OperatorStatus + * @enum {number} + * @property {number} EXPERIMENTAL=0 EXPERIMENTAL value + * @property {number} STABLE=1 STABLE value + */ + onnx.OperatorStatus = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "EXPERIMENTAL"] = 0; + values[valuesById[1] = "STABLE"] = 1; + return values; + })(); + + onnx.FunctionProto = (function() { + + /** + * Properties of a FunctionProto. + * @memberof onnx + * @interface IFunctionProto + * @property {string|null} [name] FunctionProto name + * @property {Array.|null} [input] FunctionProto input + * @property {Array.|null} [output] FunctionProto output + * @property {Array.|null} [attribute] FunctionProto attribute + * @property {Array.|null} [attributeProto] FunctionProto attributeProto + * @property {Array.|null} [node] FunctionProto node + * @property {string|null} [docString] FunctionProto docString + * @property {Array.|null} [opsetImport] FunctionProto opsetImport + * @property {string|null} [domain] FunctionProto domain + */ + + /** + * Constructs a new FunctionProto. + * @memberof onnx + * @classdesc Represents a FunctionProto. + * @implements IFunctionProto + * @constructor + * @param {onnx.IFunctionProto=} [properties] Properties to set + */ + function FunctionProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + this.attributeProto = []; + this.node = []; + this.opsetImport = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * FunctionProto name. + * @member {string} name + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.name = ""; + + /** + * FunctionProto input. + * @member {Array.} input + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.input = $util.emptyArray; + + /** + * FunctionProto output. 
+ * @member {Array.} output + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.output = $util.emptyArray; + + /** + * FunctionProto attribute. + * @member {Array.} attribute + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attribute = $util.emptyArray; + + /** + * FunctionProto attributeProto. + * @member {Array.} attributeProto + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attributeProto = $util.emptyArray; + + /** + * FunctionProto node. + * @member {Array.} node + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.node = $util.emptyArray; + + /** + * FunctionProto docString. + * @member {string} docString + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.docString = ""; + + /** + * FunctionProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.opsetImport = $util.emptyArray; + + /** + * FunctionProto domain. + * @member {string} domain + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.domain = ""; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @function create + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto=} [properties] Properties to set + * @returns {onnx.FunctionProto} FunctionProto instance + */ + FunctionProto.create = function create(properties) { + return new FunctionProto(properties); + }; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encode + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 5, wireType 2 =*/42).string(message.output[i]); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.attribute[i]); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.docString); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.domain); + if (message.attributeProto != null && message.attributeProto.length) + for (var i = 0; i < 
message.attributeProto.length; ++i) + $root.onnx.AttributeProto.encode(message.attributeProto[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.FunctionProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 4: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 5: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 6: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push(reader.string()); + break; + } + case 11: { + if (!(message.attributeProto && message.attributeProto.length)) + message.attributeProto = []; + message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 7: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 8: { + message.docString = reader.string(); + break; + } + case 9: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.domain = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a FunctionProto message. 
+ * @function verify + * @memberof onnx.FunctionProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + FunctionProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) + if (!$util.isString(message.attribute[i])) + return "attribute: string[] expected"; + } + if (message.attributeProto != null && message.hasOwnProperty("attributeProto")) { + if (!Array.isArray(message.attributeProto)) + return "attributeProto: array expected"; + for (var i = 0; i < message.attributeProto.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); + if (error) + return "attributeProto." + error; + } + } + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + return null; + }; + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.FunctionProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.FunctionProto} FunctionProto + */ + FunctionProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.FunctionProto) + return object; + var message = new $root.onnx.FunctionProto(); + if (object.name != null) + message.name = String(object.name); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.FunctionProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.FunctionProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.FunctionProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) + message.attribute[i] = String(object.attribute[i]); + } + if (object.attributeProto) { + if (!Array.isArray(object.attributeProto)) + throw TypeError(".onnx.FunctionProto.attributeProto: array expected"); + message.attributeProto = []; + for (var i = 0; i < object.attributeProto.length; ++i) { + if (typeof object.attributeProto[i] !== "object") + throw TypeError(".onnx.FunctionProto.attributeProto: object expected"); + message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); + } + } + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.FunctionProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.FunctionProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.FunctionProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.FunctionProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.domain != null) + message.domain = String(object.domain); + return message; + }; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.FunctionProto + * @static + * @param {onnx.FunctionProto} message FunctionProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + FunctionProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + object.node = []; + object.opsetImport = []; + object.attributeProto = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + object.domain = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = message.attribute[j]; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.attributeProto && message.attributeProto.length) { + object.attributeProto = []; + for (var j = 0; j < message.attributeProto.length; ++j) + object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); + } + return object; + }; + + /** + * Converts this FunctionProto to JSON. 
+ * @function toJSON + * @memberof onnx.FunctionProto + * @instance + * @returns {Object.} JSON object + */ + FunctionProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for FunctionProto + * @function getTypeUrl + * @memberof onnx.FunctionProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.FunctionProto"; + }; + + return FunctionProto; + })(); + + return onnx; +})(); + +module.exports = $root; diff --git a/js/node/test/test-utils.ts b/js/node/test/test-utils.ts index 968e8a1881810..3eef90356a335 100644 --- a/js/node/test/test-utils.ts +++ b/js/node/test/test-utils.ts @@ -4,10 +4,11 @@ import assert from 'assert'; import * as fs from 'fs-extra'; import {jsonc} from 'jsonc'; -import * as onnx_proto from 'onnx-proto'; import {InferenceSession, Tensor} from 'onnxruntime-common'; import * as path from 'path'; +import * as onnx_proto from './ort-schema/protobuf/onnx'; + export const TEST_ROOT = __dirname; export const TEST_DATA_ROOT = path.join(TEST_ROOT, 'testdata'); diff --git a/js/package-lock.json b/js/package-lock.json index c87a58a3196d6..c16a8b59a3a6f 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -3391,9 +3391,9 @@ } }, "node_modules/normalize-package-data/node_modules/semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true, "bin": { "semver": "bin/semver" @@ -7011,9 +7011,9 @@ }, "dependencies": { "semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true } } diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 0b82a9c031baa..b246e19137888 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -20,6 +20,7 @@ Do not modify directly.* | Asinh | ai.onnx(9+) | | | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | +| Attention | com.microsoft(1+) | need implementing mask and past/present | | AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | | BiasAdd | com.microsoft(1+) | | | BiasSplitGelu | com.microsoft(1+) | | @@ -61,6 +62,7 @@ Do not modify directly.* | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | | Mul | ai.onnx(7-12,13,14+) | | +| MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present | | Neg | ai.onnx(6-12,13+) | | | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | diff --git a/js/web/lib/onnxjs/attribute-with-cache-key.ts 
b/js/web/lib/onnxjs/attribute-with-cache-key.ts index 6608b00471e77..5d47570f267a6 100644 --- a/js/web/lib/onnxjs/attribute-with-cache-key.ts +++ b/js/web/lib/onnxjs/attribute-with-cache-key.ts @@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl { Object.assign(this, attribute); } - private _cacheKey: string; + private key: string; public get cacheKey(): string { - if (!this._cacheKey) { - this._cacheKey = + if (!this.key) { + this.key = Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); } - return this._cacheKey; + return this.key; } } diff --git a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts index adba0fb9d022d..ad56b92c1d869 100644 --- a/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts +++ b/js/web/lib/wasm/jsep/webgpu/attribute-with-cache-key.ts @@ -6,13 +6,13 @@ class AttributeWithCacheKeyImpl { Object.assign(this, attribute); } - private _cacheKey: string; + private key: string; public get cacheKey(): string { - if (!this._cacheKey) { - this._cacheKey = + if (!this.key) { + this.key = Object.getOwnPropertyNames(this).sort().map(name => `${(this as Record)[name]}`).join(';'); } - return this._cacheKey; + return this.key; } } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index a4d51e68b6a25..9f5dceb8f4726 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; +import {attention, parseAttentionAttributes} from './ops/attention'; import {biasAdd} from './ops/bias-add'; import {biasSplitGelu} from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; @@ -16,6 +17,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; import {range} from './ops/range'; @@ -46,6 +48,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], ['Atanh', [unaryOps.atanh]], + ['Attention', [attention, parseAttentionAttributes]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], ['BiasAdd', [biasAdd]], @@ -86,6 +89,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], + ['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]], ['Neg', [unaryOps.neg]], ['Not', [unaryOps.not]], ['Pad', [pad, parsePadAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts new file mode 100644 index 0000000000000..e1f2a47301bfb --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -0,0 +1,635 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
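+// Overview of this new WebGPU Attention kernel (com.microsoft Attention contrib op, self-attention path):
+//   1. `prepare` runs a tiled GEMM over input * weights (+ bias) to produce Q, K and V in (batch, num_heads, seq_len, head_size) layout;
+//   2. `computeAttentionProbs` computes the scaled Q*K^T attention scores per (batch, head);
+//   3. `computeInPlaceSoftmax` normalizes those scores in place;
+//   4. `computeVxAttentionScore` multiplies the probabilities by V to produce the final output.
+// Mask, past/present state and relative_position_bias inputs are validated but rejected for now (see validateAttentionInputs below).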
+ +import {TensorView} from '../../tensor-view'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; + +export const enum AttentionQkvFormat { + unknown, // enum value not set, or depends on qkv projection implementation details + qkvBNSH, // for non-packed qkv, permuted + qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + qkvBSN3H, // for TRT fused attention, qkv are packed + qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed + qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed +} + +export const enum AttentionMaskType { + none, // No mask + mask1dKeySeqLen, // [batch_size], key sequence length + mask1dEndStart, // [2 * batch_size] with end positions and start positions + mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], + // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., + // key_start[batch_size - 1], key_end[batch_size - 1]] + mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. + mask2dKeyPadding, // [batch_size, total_sequence_length] + mask3dAttention, // [batch_size, sequence_length, total_sequence_length] + mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + maskUnknown +} + +export interface AttentionParameters { + batchSize: number; + sequenceLength: number; + pastSequenceLength: number; + kvSequenceLength: number; + totalSequenceLength: number; + maxSequenceLength: number; + inputHiddenSize: number; + hiddenSize: number; + vHiddenSize: number; + headSize: number; + vHeadSize: number; + numHeads: number; + isUnidirectional: boolean; + pastPresentShareBuffer: boolean; + maskFilterValue: number; + maskType: AttentionMaskType; + scale: number; + broadcastResPosBias: boolean; + passPastInKv: boolean; + qkvFormat: AttentionQkvFormat; +} + +export interface AttentionAttrs { + numHeads: number; + isUnidirectional: number; + maskFilterValue: number; + scale: number; + doRotary: number; + qkvHiddenSizes: number[]; + pastPresentShareBuffer: boolean; +} + +const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // When past state is used, Q, K and V should have same hidden size (unless we split it into past_key and past_value). 
+ + // Input shapes: + // input (Q/K/V) : (B, S, D_i) + // weights (Q/K/V) : (D_i, D + D + D_v) + // bias (Q/K/V) : (D + D + D_v) + // mask_index : see below + // past (K/V) : (2, B, N, P, H) or NULL + // relative_position_bias : (B, N, S, T) or NULL + + // For mask_index, the following shapes are supported: + // NULL, (B, 1), (1, 1) + // (B), (2 * B), (3 * B + 2) + // (B, T) + // (B, S, T) + // (B, 1, M, M) + // + // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger + // than hidden dimension of Q, K and V. + + const input = inputs[0]; + const weights = inputs[1]; + const bias = inputs[2]; + const maskIndex = inputs[3]; + const past = inputs[4]; + const relativePositionBias = inputs[5]; + + if (past && relativePositionBias) { + throw new Error('Attention cannot have both past and relative_position_bias'); + } + + if (input.dims.length !== 3) { + throw new Error('Input "input" must have 3 dimensions'); + } + + const batchSize = input.dims[0]; + const sequenceLength = input.dims[1]; + const inputHiddenSize = input.dims[2]; + + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimensions'); + } + + if (weights.dims.length !== 2) { + throw new Error('Input "weights" is expected to have 2 dimensions'); + } + + if (weights.dims[0] !== inputHiddenSize) { + throw new Error('Input 1 dimension 0 should have same length as dimension 2 of input 0'); + } + + if (bias.dims[0] !== weights.dims[1]) { + throw new Error('Input "bias" dimension 0 should have same length as dimension 1 of input "weights"'); + } + + let qHiddenSize = bias.dims[0] / 3; + let kHiddenSize = qHiddenSize; + let vHiddenSize = kHiddenSize; + if (attributes.qkvHiddenSizes.length > 0) { + if (attributes.qkvHiddenSizes.length !== 3) { + throw new Error('qkv_hidden_sizes attribute should have 3 elements'); + } + for (const sz of attributes.qkvHiddenSizes) { + if (sz % attributes.numHeads !== 0) { + throw new Error('qkv_hidden_sizes should be divisible by num_heads'); + } + } + + qHiddenSize = attributes.qkvHiddenSizes[0]; + kHiddenSize = attributes.qkvHiddenSizes[1]; + vHiddenSize = attributes.qkvHiddenSizes[2]; + } + + const kvSequenceLength = sequenceLength; + + if (qHiddenSize !== kHiddenSize) { + throw new Error('qkv_hidden_sizes first element should be same as the second'); + } + + if (bias.dims[0] !== qHiddenSize + kHiddenSize + vHiddenSize) { + throw new Error('Input "bias" dimension 0 should have same length as sum of Q/K/V hidden sizes'); + } + + let pastSequenceLength = 0; + if (past) { + if (kHiddenSize !== vHiddenSize) { + throw new Error('Input "past" expect k_hidden_size == v_hidden_size'); + } + if (past.dims.length !== 5) { + throw new Error('Input "past" must have 5 dimensions'); + } + if (past.dims[0] !== 2) { + throw new Error('Input "past" first dimension must be 2'); + } + if (past.dims[1] !== batchSize) { + throw new Error('Input "past" second dimension must be batch_size'); + } + if (past.dims[2] !== attributes.numHeads) { + throw new Error('Input "past" third dimension must be num_heads'); + } + if (past.dims[4] !== kHiddenSize / attributes.numHeads) { + throw new Error('Input "past" fifth dimension must be k_hidden_size / num_heads'); + } + + if (!attributes.pastPresentShareBuffer) { + pastSequenceLength = past.dims[3]; + } + // TODO: handle past_seq_len + } + + const totalSequenceLength = kvSequenceLength + pastSequenceLength; + const maxSequenceLength = -1; + + const maskType = AttentionMaskType.none; + if (maskIndex) { 
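+ // Masking is not implemented in this WebGPU kernel yet, so reject any mask_index input
+ // rather than silently computing an unmasked (incorrect) result.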
+ // maskType = AttentionMaskType.MASK_UNKNOWN; + // TODO: handle mask + throw new Error('Mask not supported'); + } + + if (past) { + throw new Error('past is not supported'); + } + if (relativePositionBias) { + throw new Error('relativePositionBias is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize, + hiddenSize: qHiddenSize, + vHiddenSize, + headSize: Math.floor(qHiddenSize / attributes.numHeads), + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias: false, + passPastInKv: false, + qkvFormat: AttentionQkvFormat.qkvBNSH, + }; +}; + +export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => { + const components = getMaxComponents(d); + const inputHelper = outputVariable('x', input.dataType, input.dims, components); + + let threadMaxValue = 'threadMaxVector'; + if (components === 2) { + threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)'; + } else if (components === 4) { + threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))'; + } + const dataType = tensorTypeToWsglStorageType(input.dataType); + let WG = 64; + const dComp = d / components; + if (dComp < WG) { + WG = 1; + } else if (dComp / 8 < 64) { + WG = Math.ceil(dComp / 8); + } + const elementsPerWG = Math.ceil(d / components / WG); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const dInv: ${dataType} = 1 / ${d}; + const dComp = ${d / components}; + var wgMax: array; + var wgSum: array; + + ${shaderHelper.declareVariables(inputHelper)} + @compute @workgroup_size(${WG}, 1, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_index) local_index : u32) { + let localOffset = local_index * ${elementsPerWG}; + let offset: u32 = workgroup_id.x * dComp + localOffset; + + var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')}; + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector); + } + wgMax[local_index] = ${threadMaxValue}; + workgroupBarrier(); + + var maxValue = -3.402823e+38f; + for (var i = 0u; i < ${WG}; i++) { + maxValue = max(wgMax[i], maxValue); + } + + var sumVector = ${fillVector('f32', components, '0')}; + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue); + } + wgSum[local_index] = ${sumVector('sumVector', components)}; + workgroupBarrier(); + + var sum: f32 = 0; + for (var i = 0u; i < ${WG}; i++) { + sum += wgSum[i]; + } + + if (sum == 0) { + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + x[offset + i] = ${fillVector(dataType, components, 'dInv')}; + } + } else { + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + let f32input = ${castToF32(dataType, components, 'x[offset + i]')}; + x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum); + } + } + }`; + + context.compute( + { + name: 
'AttentionProbsSoftmax', + shaderCache: {hint: `${d}`}, + getShaderSource, + getRunData: () => ({ + outputs: [], + dispatchGroup: {x: n}, + }), + }, + {inputs: [input], outputs: []}); +}; + +const computeAttentionProbs = + (context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined, + parameters: AttentionParameters, attributes: AttentionAttrs) => { + const probsShape = [ + parameters.batchSize, parameters.numHeads, parameters.sequenceLength, + parameters.kvSequenceLength + parameters.pastSequenceLength + ]; + // TODO: handle mask + + const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; + + const dataType = tensorTypeToWsglStorageType(q.dataType); + + const components = getMaxComponents(parameters.headSize); + const qInput = inputVariable('q', q.dataType, q.dims, components); + const kInput = inputVariable('key', key.dataType, key.dims, components); + const output = outputVariable('output', q.dataType, probsShape); + + const vectorizedHeadSize = parameters.headSize / components; + const M = parameters.sequenceLength; + const N = parameters.totalSequenceLength; + const K = vectorizedHeadSize; + + const TILE_SIZE = 12; + + const dispatch = { + x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + + const inputs = [q, key]; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${M}u; + const N: u32 = ${N}u; + const K: u32 = ${K}u; + const alpha: ${dataType} = ${alpha}; + const beta: ${dataType} = 1.0; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + + ${shaderHelper.declareVariables(qInput, kInput, output)} + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + // x holds the N and y holds the M + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE; + let n = workgroup_id.x * TILE_SIZE; + let lm = m + local_id.y; + let ln = n + local_id.x; + + let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; + let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; + + var value = ${fillVector(dataType, components)}; + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m + local_id.y < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x]; + } + if (n + local_id.y < N && w + local_id.x < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x]; + } + workgroupBarrier(); + + for (var k: u32 = 0u; k ({ + outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs, outputs: [-1]})[0]; + + computeInPlaceSoftmax( + context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength, + parameters.totalSequenceLength); + + return probs; + }; + +const computeVxAttentionScore = + (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { 
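+ // output = probs x V, computed per (batch, head) as a TILE_SIZE x TILE_SIZE shared-memory tiled matmul;
+ // the result is written in (batch_size, sequence_length, v_hidden_size) layout, folding the heads back into the hidden dimension.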
+ const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize]; + + const probsHelper = inputVariable('probs', probs.dataType, probs.dims); + const vHelper = inputVariable('v', v.dataType, v.dims); + const output = outputVariable('output', probs.dataType, outputShape); + + const dataType = tensorTypeToWsglStorageType(probs.dataType); + + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(params.vHeadSize / TILE_SIZE), + y: Math.ceil(params.sequenceLength / TILE_SIZE), + z: params.batchSize * params.numHeads + }; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${params.sequenceLength}u; + const N: u32 = ${params.vHeadSize}u; + const K: u32 = ${params.totalSequenceLength}u; + const numHeads: u32 = ${params.numHeads}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + + ${shaderHelper.declareVariables(probsHelper, vHelper, output)} + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; + + let offsetA = headIdx * (M * K) + m * K; + let offsetB = headIdx * (N * K) + n; + + var value = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({ + outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs: [probs, v], outputs: [0]})[0]; + }; + +export const applyAttention = + (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined, + _past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined, + relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { + const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes); + + computeVxAttentionScore(context, probs, v, parameters); + }; + +const prepare = (context: ComputeContext, parameters: AttentionParameters) => { + const outputShape = [ + parameters.batchSize, + parameters.numHeads, + parameters.sequenceLength, + parameters.headSize, + ]; + + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + + const M = parameters.sequenceLength; + const K = parameters.inputHiddenSize; + const N = parameters.headSize; + + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(parameters.headSize / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + + const getShaderSource = () => ` + const M: u32 = ${M}u; + const K: u32 = ${K}u; + const N: u32 = ${N}u; + const numHeads: u32 = ${parameters.numHeads}; + const ldb = ${parameters.hiddenSize + parameters.hiddenSize + 
parameters.vHiddenSize}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + + @group(0) @binding(0) var input: array<${dataType}>; + @group(0) @binding(1) var weight: array<${dataType}>; + @group(0) @binding(2) var bias: array<${dataType}>; + @group(0) @binding(3) var outputQ: array<${dataType}>; + @group(0) @binding(4) var outputK: array<${dataType}>; + @group(0) @binding(5) var outputV: array<${dataType}>; + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + let batchIndex = workgroup_id.z / ${parameters.numHeads}; + let headNumber = workgroup_id.z % ${parameters.numHeads}; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; + + let inputOffset = batchIndex * (M * K) + m * K; + let biasOffsetQ = headNumber * ${parameters.headSize}; + let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ; + let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK; + + var valueQ = ${dataType}(0); + var valueK = ${dataType}(0); + var valueV = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + let offset = n + (w + local_id.y) * ldb; + tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset]; + tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset]; + tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({ + outputs: [ + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + ], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs, outputs: [-1, -1, -1]}); +}; + +export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateAttentionInputs(context.inputs, attributes); + + const [q, k, v] = prepare(context, params); + + return applyAttention( + context, q, k, v, context.inputs[4], undefined, undefined, undefined, context.inputs[5], params, attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 38dc14f23682e..014d9d02f6f10 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -646,6 +646,8 @@ export const outputVariable = (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, false, components); +export type UniformsArrayType = Array<{name: string; type: string}>; + /** * A ShaderHelper is a helper class for generating WGSL code. 
*/ @@ -697,6 +699,7 @@ export interface ShaderHelper { * A helper function to register one uniform. Can be called multiple times to register multiple uniforms. */ registerUniform(name: string, type: string): ShaderHelper; + registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper; } class ShaderHelperImpl implements ShaderHelper { @@ -755,8 +758,13 @@ class ShaderHelperImpl implements ShaderHelper { return this; } + registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper { + this.uniforms = this.uniforms.concat(additionalUniforms); + return this; + } + private indicesHelpers: IndicesHelper[] = []; - private uniforms: Array<{name: string; type: string}> = []; + private uniforms: UniformsArrayType = []; private uniformDeclaration(): string { if (this.uniforms.length === 0) { return ''; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts new file mode 100644 index 0000000000000..b7726a36bcaad --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -0,0 +1,335 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; + +const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + const query = inputs[0]; + const key = inputs[1]; + const value = inputs[2]; + const bias = inputs[3]; + const keyPaddingMask = inputs[4]; + const relativePositionBias = inputs[5]; + const pastKey = inputs[6]; + const pastValue = inputs[7]; + + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None + // relative_position_bias : (B, 1, S, L) + // past_key : (B, N, S*, H) + // past_value : (B, N, S*, H) + // When no packing for q/k/v: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) or (B, N, S*, H) + // value (V) : (B, L, D_v) or (B, N, S*, H) + // bias (Q/K/V) : (D + D + D_v) + // When packed kv is used: + // query (Q) : (B, S, D) + // key (K) : (B, L, N, 2, H) + // value (V) : None + // bias (Q/K/V) : None + // When packed qkv is used: + // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) + // key (K) : None + // value (V) : None + // bias (Q/K/V) : None or (D + D + D_v) + + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input query is expected to have 3 or 5 dimensions'); + } + + const dmmhaPacking = false; + const batchSize = query.dims[0]; + const sequenceLength 
= query.dims[1]; + const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : + attributes.numHeads * query.dims[4]; + let kvSequenceLength = sequenceLength; + + let pastSequenceLength = 0; + let maxSequenceLength = 0; + const headSize = Math.floor(hiddenSize / attributes.numHeads); + if (pastKey && pastValue) { + if (pastKey.dims.length !== 4) { + throw new Error('Input "past_key" is expected to have 4 dimensions'); + } + if (pastValue.dims.length !== 4) { + throw new Error('Input "past_value" is expected to have 4 dimensions'); + } + pastSequenceLength = pastKey.dims[2]; + maxSequenceLength = pastKey.dims[2]; + } else if (pastKey || pastValue) { + throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); + } + + let qkvFormat: AttentionQkvFormat; + if (key) { + if (query.dims.length !== 3) { + throw new Error('Input "query" is expected to have 3 dimensions when key is given'); + } + if (key.dims.length < 3 || key.dims.length > 5) { + throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions'); + } + if (query.dims[0] !== key.dims[0]) { + throw new Error('Input "query" and "key" shall have same dim 0 (batch size)'); + } + + if (key.dims.length === 3) { + if (key.dims[2] !== query.dims[2]) { + throw new Error('Input "query" and "key" shall have same dim 2 (hidden_size)'); + } + qkvFormat = AttentionQkvFormat.qkvBSNH; + kvSequenceLength = key.dims[1]; + } else if (key.dims.length === 5) { + if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) { + throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv'); + } + if (value) { + throw new Error('Expect "value" be none when "key" has packed kv format.'); + } + qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; + kvSequenceLength = key.dims[1]; + } else { // key_dims.size() == 4 (cross-attention with past_key) + if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { + throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); + } + + qkvFormat = AttentionQkvFormat.unknown; + kvSequenceLength = key.dims[2]; + } + } else { // packed QKV + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); + } + if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) { + throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv'); + } + + qkvFormat = AttentionQkvFormat.qkvBSN3H; + } + + if (bias) { + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimension'); + } + + if (value) { + if (query.dims.length === 5 && query.dims[3] === 2) { + throw new Error('bias is not allowed for packed kv.'); + } + } + } + + let maskType: AttentionMaskType = AttentionMaskType.none; + if (keyPaddingMask) { + maskType = AttentionMaskType.maskUnknown; + const maskDims = keyPaddingMask.dims; + if (maskDims.length === 1) { + if (maskDims[0] === batchSize) { + maskType = AttentionMaskType.mask1dKeySeqLen; + } else if (maskDims[0] === 3 * batchSize + 2) { + maskType = AttentionMaskType.mask1DKeySeqLenStart; + } + } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) { + maskType = AttentionMaskType.mask2dKeyPadding; + } + if (maskType === AttentionMaskType.maskUnknown) { + throw new Error('Input 
"key_padding_mask" shape shall be (batch_size) or (batch_size, kv_sequence_length)'); + } + throw new Error('Mask not supported'); + } + + let passPastInKv = false; + let vHiddenSize = hiddenSize; + if (value) { + if (value.dims.length !== 3 && value.dims.length !== 4) { + throw new Error('Input "value" is expected to have 3 or 4 dimensions'); + } + + if (query.dims[0] !== value.dims[0]) { + throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)'); + } + + if (value.dims.length === 3) { + if (kvSequenceLength !== value.dims[1]) { + throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)'); + } + vHiddenSize = value.dims[2]; + } else { + if (kvSequenceLength !== value.dims[2]) { + throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)'); + } + vHiddenSize = value.dims[1] * value.dims[3]; + passPastInKv = true; + } + } + + const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const broadcastResPosBias = false; + // if (extraAddQk) { + // if (extraAddQk.dims[0] === 1) { + // broadcastResPosBias = true; + // } + // } + + if (keyPaddingMask) { + throw new Error('Key padding mask is not supported'); + } + if (relativePositionBias) { + throw new Error('extraAddQk is not supported'); + } + if (pastKey) { + throw new Error('pastKey is not supported'); + } + if (pastValue) { + throw new Error('pastValue is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize: 0, + hiddenSize, + vHiddenSize, + headSize, + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias, + passPastInKv, + qkvFormat, + }; +}; + + +export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]}); + +const addBiasTranspose = + (context: ComputeContext, qkv: TensorView, bias: TensorView, batchSize: number, sequenceLength: number, + hiddenSize: number, biasOffset: number) => { + const outputShape = [batchSize, sequenceLength, hiddenSize]; + const outputSize = ShapeUtil.size(outputShape); + + const dataType = tensorTypeToWsglStorageType(qkv.dataType); + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const biasOffset = ${biasOffset}u; + const hiddenSize = ${hiddenSize}u; + + @group(0) @binding(0) var qkv: array<${dataType}>; + @group(0) @binding(1) var bias: array<${dataType}>; + @group(0) @binding(2) var qkv_with_bias: array<${dataType}>; + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset; + + qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx]; + }`; + + return context.compute( + { + name: 'MultiHeadAttentionAddBias', + shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + }), + getShaderSource, + }, + {inputs: [qkv, bias], outputs: [-1]})[0]; + }; + +const 
maybeTransposeToBNSHAndAddBias = + (context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number, + input: TensorView, bias?: TensorView, biasOffset?: number) => { + // const newDims = []; + + let reshapedInput = input; + if (!bias) { + if (input.dims.length === 3) { + reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); + } + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } else { + if (sequenceLength === 1) { + throw new Error('AddBiasReshape is not implemented. Please export your model with packed QKV or KV'); + } else { + reshapedInput = + addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!); + reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } + } + }; + +export const multiHeadAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateInputs(context.inputs, attributes); + + if (context.inputs[0].dims.length === 5) { + throw new Error('Packed QKV is not implemented'); + } + + if (context.inputs[1]?.dims.length === 5) { + throw new Error('Packed KV is not implemented'); + } + + // applyAttention expects BNSH inputs + const kvBNSH = context.inputs[1] && context.inputs[2] && context.inputs[1].dims.length === 4 && + context.inputs[2].dims.length === 4; + + const Q = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0], + context.inputs[3], 0); + + if (kvBNSH) { + return applyAttention( + context, Q, context.inputs[1], context.inputs[2], context.inputs[4], undefined, undefined, undefined, + context.inputs[5], params, attributes); + } + + const K = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.headSize, context.inputs[1], + context.inputs[3], params.hiddenSize); + + const V = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.vHeadSize, context.inputs[2], + context.inputs[3], 2 * params.hiddenSize); + + applyAttention( + context, Q, K, V, context.inputs[4], undefined, context.inputs[6], context.inputs[7], context.inputs[5], params, + attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index d607351f69b74..7458579bf4340 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo, TensorInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; export interface SliceAttributes extends AttributeWithCacheKey { readonly starts: number[]; @@ -77,17 +77,26 @@ const fixStartEndValues = }; const 
calculateInputIndicesImpl = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]): - string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], + enableInputShapeUniforms: boolean): string => + `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { var inputIndices: ${input.type.indices}; var carry = 0u; for (var i = ${inputShape.length}; i >= 0; i--) { + let input_shape_i = ${ + enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'}; + let steps_i = ${ + enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'}; + let signs_i = ${ + enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'}; + let starts_i = ${ + enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'}; var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; - var inputIndex = outputIndex * steps[i] + starts[i] + carry; - carry = inputIndex / inputShape[i]; - inputIndex = inputIndex % inputShape[i]; - if (signs[i] < 0) { - inputIndex = inputShape[i] - inputIndex - 1u + starts[i]; + var inputIndex = outputIndex * steps_i + starts_i + carry; + carry = inputIndex / input_shape_i; + inputIndex = inputIndex % input_shape_i; + if (signs_i < 0) { + inputIndex = input_shape_i - inputIndex - 1u + starts_i; } ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex; } @@ -110,6 +119,10 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps)); + if (axes.length !== starts.length || axes.length !== ends.length) { + throw new Error('start, ends and axes should have the same number of elements'); + } + if (axes.length !== inputShape.length) { for (let i = 0; i < inputShape.length; ++i) { if (!axes.includes(i)) { @@ -131,40 +144,66 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice array[i] = -step; } }); + // Output rank is expected to be less than or equal to the input rank. + const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length); + const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims; const outputShape = inputShape.slice(0); axes.forEach((axis, _) => { outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); }); + const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape; const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType}; - const output = outputVariable('output', inputs[0].dataType, outputShape); - const input = inputVariable('input', inputs[0].dataType, inputShape); + const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); + const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank); const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = []; + const uniforms: UniformsArrayType = []; + if (enableShapeUniforms) { + uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}` : 'u32'}); + uniforms.push({name: 'signs', type: signs.length > 1 ? 
`vec${signs.length}` : 'i32'}); + uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}` : 'u32'}); + programUniforms.push({type: 'uint32', data: starts}); + programUniforms.push({type: 'int32', data: signs}); + programUniforms.push({type: 'uint32', data: steps}); + } + uniforms.push({name: 'outputSize', type: 'u32'}); + programUniforms.push({type: 'uint32', data: outputSize}); + if (enableShapeUniforms) { + programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); + programUniforms.push(...createTensorShapeVariables(outputShape)); + } const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, output)} - const signs = array(${signs.map(i => `${i}i`).join(',')}); - const starts = array(${starts.map(i => `${i}u`).join(',')}); - const ends = array(${ends.map(i => `${i}u`).join(',')}); - const steps = array(${steps.map(i => `${i}u`).join(',')}); - const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - - ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} + ${enableShapeUniforms ? '' : [ + `const signs = array(${signs.map(i => `${i}i`).join(',')});`, + `const starts = array(${starts.map(i => `${i}u`).join(',')});`, + `const steps = array(${steps.map(i => `${i}u`).join(',')});`, + `const inputShape = array(${inputShape.map(i => `${i}u`).join(',')});` + ].join('\n')} + + ${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} let outputIndices = ${output.offsetToIndices('global_idx')}; let inputIndices = calculateInputIndices(outputIndices); ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} }`; return { name: 'Slice', - shaderCache: {hint: `${attributes.cacheKey}|${inputs[4]?.dims ?? ''}`}, + shaderCache: { + hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` : + `${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`, + inputDependencies: [enableShapeUniforms ? 
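For orientation, the mapping that `calculateInputIndices` encodes reduces, in the common positive-step case, to `start + outputIndex * step` per axis; the extra carry and `signs` handling above exists for negative steps. A small illustrative sketch (plain TypeScript, not the WGSL helper itself), using the new `slice.jsonc` case added later in this diff as the example:

```ts
// Per-axis mapping for positive steps: output coordinate o on an axis reads
// input coordinate starts[axis] + o * steps[axis].
const sliceOutputToInput = (outputIndices: number[], starts: number[], steps: number[]): number[] =>
    outputIndices.map((o, axis) => starts[axis] + o * steps[axis]);

// A [1, 1, 1, 1, 5] input sliced with starts=[3], ends=[4] on axis 4 keeps one
// element, read from input index 3 on that axis.
console.log(sliceOutputToInput([0, 0, 0, 0, 0], [0, 0, 0, 0, 3], [1, 1, 1, 1, 1]));  // [0, 0, 0, 0, 3]
```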
'rank' : 'dims'] + }, getShaderSource, getRunData: () => ({ outputs: [outputTensorInfo], dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)}, + programUniforms }) }; }; diff --git a/js/web/script/generate-webgpu-operator-md.ts b/js/web/script/generate-webgpu-operator-md.ts index 7408f17004f5e..eab8175a941bd 100644 --- a/js/web/script/generate-webgpu-operator-md.ts +++ b/js/web/script/generate-webgpu-operator-md.ts @@ -16,6 +16,8 @@ const COMMENTS: Record = { 'Reshape': 'no GPU kernel', 'Shape': 'no GPU kernel; an ORT warning is generated - need to fix', 'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling', + 'Attention': 'need implementing mask and past/present', + 'MultiHeadAttention': 'need implementing mask and past/present', }; /* eslint-disable max-len */ diff --git a/js/web/test/data/ops/attention.jsonc b/js/web/test/data/ops/attention.jsonc new file mode 100644 index 0000000000000..bd4483027cc25 --- /dev/null +++ b/js/web/test/data/ops/attention.jsonc @@ -0,0 +1,557 @@ +[ + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [4, 3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [213, 213], + "dims": [1, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic Batch 2 with 2 heads", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16 + ], + "dims": [2, 2, 8], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 + ], + "dims": [8, 6], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [320, 321, 320, 321, 320, 321, 320, 321], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863], + "dims": [1, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1.328187108039856, -1.297916054725647, -0.8599594831466675], + "dims": [1, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic one head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + 
"dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [2, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 2 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643], + "dims": [2, 6], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989, -0.989, 1.1103, -1.6898], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.8701779842376709, -2.6158859729766846, 0.8710794448852539, -2.5763747692108154, 0.9005484580993652, + -2.182751178741455, 2.1661579608917236, -2.1045265197753906, 1.6716957092285156, -1.797281265258789, + 1.7134947776794434, -1.765358328819275 + ], + "dims": [2, 3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987 + ], + "dims": [2, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": 
[5, 15], + "type": "float32" + }, + { + "data": [ + 1.1103, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, -1.6898, -0.989, -1.9029953479766846, 0.8710794448852539, + -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, 1.7134947776794434 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.6956915855407715, -2.8863370418548584, 1.3899128437042236, 1.6789076328277588, -1.4083852767944336, + -1.7009180784225464, -3.1053788661956787, 3.5959298610687256, 1.1027096509933472, -0.009643087163567543, + -1.694351315498352, -2.9284396171569824, 1.734721302986145, 2.0606398582458496, -0.2571452260017395, + 3.671973943710327, -5.285338401794434, -6.833454132080078, 1.7506506443023682, -2.262148380279541, + 2.5110034942626953, 1.440049171447754, -0.9423203468322754, 1.7506506443023682, -1.86212158203125, + -0.5036701560020447, -5.732386589050293, -1.5674757957458496, 1.7506510019302368, -2.264472246170044 + ], + "dims": [2, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846 + ], + "dims": [1, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [1, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ 
"name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [3, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + 3.7965505123138428, -2.3799397945404053, -3.9530906677246094, 0.5844926834106445, -2.9756431579589844, + 2.448162794113159, 4.34546422958374, 1.9380426406860352, 0.5870105624198914, -2.7368364334106445, + -0.4769568145275116, 4.255186557769775, -3.9529950618743896, 0.6987408995628357, -2.9756433963775635 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 
0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, 
-0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.01101303100586, -5.782258987426758, 6.016238689422607, 0.26747000217437744, -6.992541313171387, + -8.011263847351074, -5.782248020172119, 5.366001129150391, 0.26747000217437744, -6.99449348449707, + -8.011263847351074, -5.782265663146973, 6.016238689422607, 0.26747000217437744, -6.992537021636963, + -6.102723598480225, -7.28973388671875, -4.578637599945068, 7.2203369140625, -6.028444766998291, + -6.102705478668213, -7.2897748947143555, -3.7882626056671143, 5.393260478973389, -5.754333972930908, + -1.3616288900375366, -7.289827823638916, -6.341128349304199, 6.329389572143555, -5.751791954040527, + -2.3945987224578857, -14.532954216003418, 3.969801902770996, 12.744998931884766, -11.1966552734375, + -2.4002532958984375, -14.538958549499512, -6.684961318969727, 12.476543426513672, -9.24352741241455, + -4.787771701812744, -8.640848159790039, 3.969801902770996, -0.6471102833747864, -11.1966552734375 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 1 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 
0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + 1.3541864156723022, -7.813620090484619, -6.758509635925293, 7.597365856170654, -13.926229476928711, + -1.322464108467102, -7.297357559204102, -0.05962071940302849, 6.347561836242676, -5.869992256164551, + -1.3616288900375366, -7.28973388671875, 0.0386197566986084, 6.329389572143555, -5.751791954040527, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + 
1.021930456161499, -2.373898983001709, 3.8501391410827637, -0.6108309626579285, -9.256340980529785 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/multi-head-attention.jsonc b/js/web/test/data/ops/multi-head-attention.jsonc new file mode 100644 index 0000000000000..05687bd482e24 --- /dev/null +++ b/js/web/test/data/ops/multi-head-attention.jsonc @@ -0,0 +1,194 @@ +[ + { + "name": "MultiHeadAttention Basic, one head", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.973228454589844, 5.973228454589844, 6.973228454589844, 7.973228454589844, 4.999990940093994, + 5.999990940093994, 6.999990940093994, 7.999990940093994 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.571832656860352, 5.571832656860352, 6.971858501434326, 7.971858501434326, 4.998325824737549, + 5.998325824737549, 6.999900817871094, 7.999900817871094 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic with bias", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4], + "dims": [12], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 5.943336009979248, 7.94333553314209, 9.999799728393555, 11.999798774719238, 5.9997992515563965, + 7.9997992515563965, 10, 11.999999046325684 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 8.99963665008545, 9.99963665008545, 10.99963665008545, 11.999635696411133, 13, 14, 15, 16, 
9, 10, 11, 12, + 13, 14, 15, 16 + ], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[1]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 1, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 1, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/slice.jsonc b/js/web/test/data/ops/slice.jsonc index 9c90817a80c36..beef154a29932 100644 --- a/js/web/test/data/ops/slice.jsonc +++ b/js/web/test/data/ops/slice.jsonc @@ -21,6 +21,29 @@ } ] }, + { + "name": "Slice float32 with input[0] dim > 4", + "operator": "Slice", + "attributes": [], + "cases": [ + { + "name": "T[1, 1, 1, 1, 5] T[1] T[1] T[1] (float32)", + "inputs": [ + { + "data": [ + 0.3964604139328003, -0.8916832804679871, -1.6578896045684814, 1.960708737373352, 1.181204915046692 + ], + "dims": [1, 1, 1, 1, 5], + "type": "float32" + }, + { "data": [3], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" }, + { "data": [4], "dims": [1], "type": "int64" } + ], + "outputs": [{ "data": [1.960708737373352], "dims": [1, 1, 1, 1, 1], "type": "float32" }] + } + ] + }, { "name": "Slice int32", "operator": "Slice", diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index c80f0b04a9abc..37aa9394c7f96 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1336,6 +1336,7 @@ "add_int32.jsonc", //"and.jsonc", "asin.jsonc", + "attention.jsonc", "bias-add.jsonc", "bias-split-gelu.jsonc", "ceil.jsonc", @@ -1362,6 +1363,7 @@ "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", + "multi-head-attention.jsonc", //"neg.jsonc", "neg-int32.jsonc", "not.jsonc", diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 0ed7d887fc5e5..57219c50f39aa 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -61,7 +61,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import OrtValue # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import SparseTensor # noqa: F401 -from onnxruntime.capi.training import * # noqa: F403 # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end try: # noqa: SIM105 diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index b693b58c7c40a..a7f83469a768d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -96,9 +96,9 @@ struct GroupQueryAttentionParameters { int kv_num_heads; int num_splits; // number of splits for splitkv bool is_unidirectional; // causal + int local_window_size; bool kv_share_buffer; - bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor - bool left_padding; // copies last token to last index if true + bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat 
past_kv_format; diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 4a266af789250..47f462d75fcc4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -63,6 +63,16 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int head_size = parameters.head_size; const int position_ids_format = parameters.position_ids_format; const int half_head_size = head_size / 2; + // Default input tensor shape is [batch, seq_len, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (parameters.transposed) { + // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); @@ -76,11 +86,10 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int s = static_cast((ptr / num_heads) % sequence_length); const int n = static_cast(ptr % num_heads); - const int block_offset = b * sequence_length * num_heads + s * num_heads + n; - const int data_offset = block_offset * head_size; + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - const T* input_data = input_src + data_offset; - T* output_data = output_dest + data_offset; + const T* input_data = input_src + block_offset; + T* output_data = output_dest + block_offset; // Cache is (M, H/2) const int position_id = (position_ids_format == 0) diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index cf8080800e072..7b2e8289f7b06 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -18,6 +18,7 @@ struct RotaryParameters { int num_heads; // num_heads = hidden_size / head_size int max_sequence_length; // Sequence length used by cos/sin cache int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; template @@ -33,8 +34,8 @@ Status CheckInputs(const T* input, // Check input const auto& input_dims = input->Shape().GetDims(); - if (input_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ", + if (input_dims.size() != 3 && input_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ", input_dims.size()); } // Check position_ids @@ -63,6 +64,14 @@ Status CheckInputs(const T* input, int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); int hidden_size = static_cast(input_dims[2]); + + bool transposed = false; + if (input_dims.size() == 4) { + // input is [batch, num_heads, seq, head_size] + sequence_length = static_cast(input_dims[2]); + hidden_size = static_cast(input_dims[1]) * static_cast(input_dims[3]); + transposed = true; + } int max_sequence_length = static_cast(cos_cache_dims[0]); int head_size = static_cast(cos_cache_dims[1]) * 2; int num_heads = hidden_size / head_size; @@ -111,6 +120,7 @@ Status CheckInputs(const T* input, output_parameters->num_heads 
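To summarize the stride selection in `rotary_embedding.cc` above: the default 3-D input is (batch, sequence, num_heads × head_size), while the newly supported transposed 4-D input is (batch, num_heads, sequence, head_size); only the strides change, the per-head rotation itself is untouched. A short sketch of the resulting block offsets (illustration only, not the kernel code):

```ts
// Offset of the first element of (batch b, sequence position s, head n) under
// the two layouts handled by RotaryEmbedding::Compute.
const rotaryBlockOffset =
    (b: number, s: number, n: number, numHeads: number, seqLen: number, headSize: number,
     transposed: boolean): number => {
      const headStride = transposed ? seqLen * headSize : headSize;
      const seqStride = transposed ? headSize : numHeads * headSize;
      const batchStride = transposed ? numHeads * headStride : seqLen * seqStride;
      return b * batchStride + s * seqStride + n * headStride;
    };
```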
= num_heads; output_parameters->max_sequence_length = max_sequence_length; output_parameters->position_ids_format = position_ids_format; + output_parameters->transposed = transposed; } return Status::OK(); @@ -118,4 +128,4 @@ Status CheckInputs(const T* input, } // namespace rotary_embedding_helper } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 89e2351428d40..cbe536c6ce45a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -69,6 +69,7 @@ struct Flash_fwd_params : public Qkv_params { int seqlen_q_rounded = 0; int seqlen_k_rounded = 0; int d_rounded = 0; + int rotary_dim = 0; // The scaling factors for the kernel. float scale_softmax = 0.0; @@ -92,12 +93,26 @@ struct Flash_fwd_params : public Qkv_params { index_t knew_head_stride = 0; index_t vnew_head_stride = 0; + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr = nullptr; + void* __restrict__ rotary_sin_ptr = nullptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx = nullptr; + + // Local window size + int window_size_left = -1; + int window_size_right = -1; + bool is_bf16 = false; bool is_causal = false; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. bool is_seqlens_k_cumulative = true; + + bool is_rotary_interleaved = false; + int num_splits = 0; // For split-KV version const cudaDeviceProp* dprops = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 89a27c4d2b0d3..76190aad68fdb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -35,7 +35,9 @@ void set_params_fprop(Flash_fwd_params& params, void* softmax_lse_d, float softmax_scale, bool is_causal, - bool kv_bsnh = true) { + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { // Set the pointers and strides. params.q_ptr = q; params.k_ptr = k; @@ -102,7 +104,21 @@ void set_params_fprop(Flash_fwd_params& params, params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; + // In our API, causal/unidirectional determines if we only look at prior tokens. 
However, the flash API seperates + // local and causal, meaning when we have local window size params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.is_seqlens_k_cumulative = true; } @@ -227,7 +243,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - bool kv_bsnh) { + bool kv_bsnh, + int local_window_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); @@ -247,7 +264,9 @@ Status mha_fwd(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, - kv_bsnh); + kv_bsnh, + local_window_size, + is_causal ? 0 : -1); params.dprops = &dprops; params.knew_ptr = nullptr; params.vnew_ptr = nullptr; @@ -306,7 +325,10 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, nullptr, softmax_lse, softmax_scale, - is_causal); + is_causal, + true, + -1, + is_causal ? 0 : -1); params.dprops = &dprops; params.num_splits = 0; params.softmax_lseaccum_ptr = nullptr; @@ -347,11 +369,11 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, bool past_bsnh, // otherwise bnsh int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads - void* out_accum // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded -) { - if (seqlen_q == 1) { - is_causal = false; - } // causal=true is the same as causal=false in this case + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size) { + // if (seqlen_q == 1) { + // is_causal = false; + // } // causal=true is the same as causal=false in this case auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); @@ -372,7 +394,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal, - past_bsnh); + past_bsnh, + local_window_size, + is_causal ? 
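The window handling added to `set_params_fprop` above reconciles ORT's `is_causal` + `local_window_size` view with flash attention's (left, right) window pair: the updated callers pass `window_size_right = is_causal ? 0 : -1`, and once any window is specified `params.is_causal` is cleared while a -1 on either side is widened to the full key length. A sketch of that normalization (plain TypeScript for illustration; the real implementation is the C++ above):

```ts
// Mirror of the normalization in set_params_fprop: -1 means "unbounded" on that
// side, and causality is carried by a right window of 0 once a window is in play.
const normalizeWindow = (isCausal: boolean, left: number, right: number, seqlenK: number) => {
  let causal = isCausal;
  if (isCausal && (left >= 0 || right !== 0)) {
    causal = false;  // windowed/local attention takes over
  }
  if (left < 0 && right >= 0) {
    left = seqlenK;
  }
  if (left >= 0 && right < 0) {
    right = seqlenK;
  }
  return {causal, left, right};
};

// e.g. a causal call with local_window_size = 256 arrives as (true, 256, 0) and
// normalizes to {causal: false, left: 256, right: 0}.
```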
0 : -1); params.dprops = &dprops; if (k != nullptr && v != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 58f4304251872..efc1f565c4fa0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -54,7 +54,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - bool kv_bsnh = true); + bool kv_bsnh = true, + int local_window_size = -1); Status mha_varlen_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, @@ -96,8 +97,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, bool past_bsnh, // otherwise bnsh int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads - void* out_accum = nullptr // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded -); + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size = -1); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index eb1c794d6df54..028233f66850f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -29,47 +29,6 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTE_HOST_DEVICE auto -make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - make_layout(cute::size<2>(TileShape_MNK{}))); - - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTE_HOST_DEVICE auto -make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) { - using TileShape_MNK = typename TiledMMA::TiledShape_MNK; - using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; - constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; - constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; - constexpr int MMAStride_M = MMA_M * AtomShape_M; - auto t = make_tile(cute::Layout, cute::Int>, - cute::Stride<_1, cute::Int>>{}, - // TODO: Shouldn't this be size<1>? 
- make_layout(cute::size<2>(TileShape_MNK{}))); - // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, Tensor2& acc_o, float softmax_scale_log2) { @@ -123,7 +82,7 @@ inline __device__ void write_softmax_to_gmem( //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -144,12 +103,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); // We exit early and write 0 to gO and gLSE. // Otherwise we might read OOB elements from gK and gV. - if (n_block_max <= 0) { + if (n_block_max <= n_block_min) { const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), @@ -197,7 +158,6 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), cute::Shape, cute::Int>{}, make_stride(params.q_row_stride, _1{})); @@ -332,9 +292,9 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -364,22 +324,22 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. - if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -390,8 +350,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? 
softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // Convert scores from fp32 to fp16/bf16 cute::Tensor rP = flash::convert_type(scores); @@ -408,14 +368,14 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration - if (n_masking_steps > 1 && n_block <= 0) { + if (n_masking_steps > 1 && n_block <= n_block_min) { --n_block; break; } } // These are the iterations where we don't need masking on S - for (; n_block >= 0; --n_block) { + for (; n_block >= n_block_min; --n_block) { cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) clear(acc_s); flash::cp_async_wait<0>(); @@ -431,7 +391,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi flash::cp_async_wait<0>(); __syncthreads(); - if (n_block > 0) { + if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -441,8 +401,15 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi } // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); cute::Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -543,7 +510,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; @@ -572,11 +539,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons if (m_block * kBlockM >= binfo.actual_seqlen_q) return; const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; - const int n_block_min = n_split_idx * n_blocks_per_split; + const int n_block_min = !Is_local + ? 
n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); - if (Is_causal) { + if (Is_causal || Is_local) { n_block_max = std::min(n_block_max, - cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); } if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 // We exit early and write 0 to gOaccum and -inf to gLSEaccum. @@ -626,10 +595,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -641,16 +609,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, make_stride(params.v_row_stride, _1{})); - // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, - // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. - // This maps to accessing the first 64 rows of knew_ptr. 
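For reference, the range of K blocks [n_block_min, n_block_max) visited by one query row block under causal or local (sliding-window) masking can be restated on the host; the helper below is an illustrative sketch that mirrors the arithmetic of the kernel hunks above, not kernel code (for purely causal attention the callers end up passing window_size_right = 0).

#include <algorithm>

// Illustrative, host-side version of the K-block bounds computed in
// compute_attn_1rowblock / compute_attn_1rowblock_splitkv for the row block
// [m_block * kBlockM, (m_block + 1) * kBlockM).
struct KBlockRange { int n_block_min; int n_block_max; };

KBlockRange k_block_range(int m_block, int kBlockM, int kBlockN,
                          int seqlen_q, int seqlen_k,
                          bool is_causal, bool is_local,
                          int window_size_left, int window_size_right) {
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
  const int n_block_min = !is_local
      ? 0
      : std::max(0, (m_block * kBlockM + seqlen_k - seqlen_q - window_size_left) / kBlockN);
  int n_block_max = ceil_div(seqlen_k, kBlockN);
  if (is_causal || is_local) {
    n_block_max = std::min(
        n_block_max,
        ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_size_right, kBlockN));
  }
  return {n_block_min, n_block_max};
}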
- Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), - Shape, Int>{}, - make_stride(params.knew_row_stride, _1{})); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } - Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), - Shape, Int>{}, - make_stride(params.vnew_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQ{}); @@ -664,11 +622,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); typename Kernel_traits::TiledMma tiled_mma; @@ -732,17 +688,129 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons } // Prologue + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. 
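Before the cos/sin tiles are set up below, it helps to state what copy_rotary_interleaved ultimately computes per key/query row: consecutive channel pairs within the first rotary_dim channels are rotated by that position's cos/sin values, and channels beyond rotary_dim pass through unchanged. A minimal host-side sketch (illustrative only, not the kernel code):

#include <vector>

// Illustrative sketch of the interleaved rotary embedding applied by
// copy_rotary_interleaved: x[2i], x[2i+1] form a pair rotated by (cos[i], sin[i]).
void rotary_interleaved_row(std::vector<float>& x,
                            const std::vector<float>& cos_v,
                            const std::vector<float>& sin_v,
                            int rotary_dim) {
  for (int i = 0; i < rotary_dim / 2; ++i) {
    const float x0 = x[2 * i];
    const float x1 = x[2 * i + 1];
    x[2 * i]     = x0 * cos_v[i] - x1 * sin_v[i];
    x[2 * i + 1] = x0 * sin_v[i] + x1 * cos_v[i];
  }
  // Channels x[rotary_dim], ..., x[d - 1] are copied through unchanged.
}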
+ const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. 
+ Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + } + } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + if (n_block_max > n_block_copy_min) { + tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; + tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; + } + } + // Read Q from gmem to smem, optionally apply rotary embedding. Tensor tQrQ = make_fragment_like(tQgQ); - // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? 
m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } + } int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // flash::cp_async_wait<0>(); @@ -760,9 +828,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. - constexpr int n_masking_steps = !Is_causal + constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 - : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -770,32 +838,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (cute::thread0()) { print(tKgK); } - // if (cute::thread0()) { print(tKsK); } - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - // __syncthreads(); - // if (cute::thread0()) { print(tKgK); } - // __syncthreads(); - } - // Advance gV if (masking_step > 0) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - if (Append_KV) { - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); } cute::cp_async_fence(); @@ -810,15 +860,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We don't put the masking before the matmul S = Q K^T because we don't clear sK // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. 
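The masking calls below switch from the purely causal apply_mask_causal to the sliding-window apply_mask_local added in softmax.h; conceptually, a score at (row, col) survives only inside the window captured by this host-side predicate (an illustrative sketch; causal masking is the special case of a disabled left window with window_size_right = 0):

#include <algorithm>

// Illustrative predicate mirroring apply_mask_local: the last query row is aligned
// with the last key row via the seqlen_k - seqlen_q shift, then a column is kept
// only inside [row_shifted - window_size_left, row_shifted + window_size_right].
bool keep_score(int row, int col, int seqlen_q, int seqlen_k,
                int window_size_left, int window_size_right) {
  const int limit_left  = std::max(0, row + seqlen_k - seqlen_q - window_size_left);
  const int limit_right = std::min(seqlen_k, row + 1 + seqlen_k - seqlen_q + window_size_right);
  return col >= limit_left && col < limit_right;
}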
- if (!Is_causal) { + if (!Is_causal && !Is_local) { if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - kNWarps * 16); + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); } flash::cp_async_wait<0>(); @@ -826,26 +876,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } // __syncthreads(); - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); } - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } - if (n_block > n_block_min) { // Advance gK - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); } tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - if (Append_KV) { - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - } - // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, - binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -853,8 +887,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? 
softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } // Convert scores from fp32 to fp16/bf16 @@ -879,20 +913,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } // Advance gV tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - if (Append_KV) { - tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( @@ -901,22 +924,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons flash::cp_async_wait<0>(); __syncthreads(); - if constexpr (Append_KV) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); } - if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) { - flash::copy_w_min_idx( - tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); - } - } if (n_block > n_block_min) { // Advance gK tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - if (Append_KV) { - tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); - } - flash::copy_2_sources( - gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0, - binfo.seqlen_k_cache - (n_block - 1) * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); @@ -924,7 +935,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) @@ -1031,7 +1049,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1047,12 +1065,12 @@ inline __device__ void compute_attn(const Params& params) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params& params) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1061,24 +1079,23 @@ inline __device__ void compute_attn_splitkv(const Params& params) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void combine_attn_seqk_parallel(const Params& params) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; constexpr int kMaxSplits = 1 << Log_max_splits; - constexpr int kBlockM = 16; constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer"); - static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32"); - static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); // Shared memory. 
// kBlockM + 1 instead of kBlockM to reduce bank conflicts. @@ -1094,10 +1111,10 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { make_stride(params.b * params.h * params.seqlen_q, _1{})); Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); - constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads; + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // Read the LSE values from gmem and store them in shared memory, then tranpose them. - constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM; + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; #pragma unroll for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; @@ -1165,7 +1182,12 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), Shape, Int>{}, Stride, _1>{}); - typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); Tensor tOrO = make_tensor(shape(tOgOaccum)); @@ -1183,8 +1205,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } } -// Load Oaccum in then scale and accumulate to O -#pragma unroll 2 + // Load Oaccum in then scale and accumulate to O for (int split = 0; split < params.num_splits; ++split) { flash::copy( gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index 82dfa59b8f8e7..87d189a803f8a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -10,29 +10,30 @@ namespace onnxruntime { namespace flash { -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn(params); + flash::compute_attn(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - flash::compute_attn_splitkv(params); + flash::compute_attn_splitkv(params); #else (void)params; #endif } -template +template __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static_assert(Log_max_splits >= 1); - flash::combine_attn_seqk_parallel(params); + flash::combine_attn_seqk_parallel(params); #else (void)params; #endif @@ -52,20 +53,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { const bool is_even_K = params.d == 
Kernel_traits::kHeadDim; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - // ORT_ENFORCE(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); }); }); } @@ -82,40 +88,46 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(params.num_splits > 1, Split, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); - auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV > ; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } - kernel<<>>(params); + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. 
+ // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); }); }); }); }); }); if (params.num_splits > 1) { - dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16); + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); + flash_fwd_splitkv_combine_kernel<<>>(params); } }); } @@ -130,7 +142,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 32; + constexpr static int Headdim = 32; BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_causal>(params, stream); }); @@ -138,7 +150,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 64; + constexpr static int Headdim = 64; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower // Using block size (64 x 256) is 27% slower for seqlen=2k @@ -174,8 +186,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 128; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 128; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. 
@@ -201,8 +213,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { - constexpr int Headdim = 160; - const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + constexpr static int Headdim = 160; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, H100, 128 x 32 is the fastest. // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -241,12 +253,11 @@ void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 224; - constexpr size_t threshold = 2 * Headdim * (128 + 2 * 64); - size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + constexpr static int Headdim = 224; + int max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_block = %d\n", max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= threshold) { // 112 KB + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); @@ -262,16 +273,14 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { - constexpr size_t Headdim = 256; - constexpr size_t min_threshold = 2 * Headdim * (128 + 2 * 64); - constexpr size_t max_threshold = 4 * Headdim * (64 + 2 * 64); + constexpr static int Headdim = 256; size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= min_threshold && max_smem_per_sm < max_threshold) { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { run_flash_fwd, Is_causal>(params, stream); } else { run_flash_fwd, Is_causal>(params, stream); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h index 134f159e258c4..1c0ed7f2fc2e8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -161,7 +161,14 @@ struct Flash_fwd_kernel_traits : public Base { cute::Stride<_16, _1>>>; using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, GmemLayoutAtomOaccum{}, - cute::Layout>{})); // Val layout, 4 vals per store + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. 
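The softmax.h and utils.h hunks that follow also add a contiguous ("half-rotated") rotary path, copy_rotary_contiguous, in which channel i is paired with channel i ± rotary_dim/2 rather than with its neighbour; a host-side sketch of that variant (illustrative only):

#include <vector>

// Illustrative sketch of the contiguous rotary layout used by copy_rotary_contiguous:
// both halves of the rotary dimensions share the cos/sin index i % (rotary_dim / 2).
void rotary_contiguous_row(std::vector<float>& x,
                           const std::vector<float>& cos_v,
                           const std::vector<float>& sin_v,
                           int rotary_dim) {
  const int half = rotary_dim / 2;
  const std::vector<float> orig(x.begin(), x.begin() + rotary_dim);
  for (int i = 0; i < rotary_dim; ++i) {
    const bool is_left = i < half;
    const int pair = is_left ? i + half : i - half;   // the partner channel
    const int cs = is_left ? i : i - half;            // shared cos/sin index
    x[i] = orig[i] * cos_v[cs] + orig[pair] * (is_left ? -sin_v[cs] : sin_v[cs]);
  }
}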
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 842edf3a98a86..8017f83bbb01d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor& tensor, const int max_ } } -template -inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride) { +template +inline __device__ void apply_mask_local(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; @@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit) { + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } } @@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor& tensor, const i } } +template +inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, + max_seqlen_q, warp_row_stride, -1, 0); +} + template inline __device__ void apply_mask_causal_w_idx( Tensor& tensor, Tensor const& idx_rowcol, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h index 02042e183f808..271112c5e890a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -307,7 +307,7 @@ template inline __device__ void copy(TiledCopy tiled_copy, Tensor const& S, Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, int max_MN = 0) { + Tensor const& predicate_K, const int max_MN = 0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA @@ -334,65 +334,161 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor const //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor const& S0, - Tensor const& S1, +inline __device__ void 
copy_w_min_idx(Tensor const& S, Tensor& D, Tensor const& identity_MN, Tensor const& predicate_K, - const int max_MN = 0, const int row_idx_switch = 0) { - CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{}); + const int max_MN = 0, const int min_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K - // There's no case where !Clear_OOB_K && Clear_OOB_MN - static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); -// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); } -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); } + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } #pragma unroll - for (int m = 0; m < size<1>(S0); ++m) { - auto& S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1; - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } #pragma unroll - for (int k = 0; k < size<2>(S0); ++k) { + for (int k = 0; k < size<2>(S); ++k) { if (Is_even_K || predicate_K(k)) { - cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy_rotary_interleaved(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) 
>= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { cute::clear(D(_, m, k)); } } - } else if (Clear_OOB_MN) { - cute::clear(D(_, m, _)); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_w_min_idx(Tensor const& S, - Tensor& D, Tensor const& identity_MN, - Tensor const& predicate_K, - const int max_MN = 0, const int min_MN = 0) { +inline __device__ void copy_rotary_contiguous(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); #pragma unroll for (int m = 0; m < size<1>(S); ++m) { - // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { -// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } #pragma unroll for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || predicate_K(k)) { - cute::copy(S(_, m, k), D(_, m, k)); + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if 
(get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); } } } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index f21dff08e0350..93892169f6c79 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -44,9 +44,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); num_heads_ = static_cast(num_heads); kv_num_heads_ = static_cast(kv_num_heads); - is_unidirectional_ = true; - // left_padding_ = info.GetAttrOrDefault("left_padding_last_token", 0) == 1; is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -92,8 +91,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { is_past_bsnh_, scale_, device_prop.maxThreadsPerBlock)); - parameters.is_unidirectional = is_unidirectional_; - // parameters.left_padding = left_padding_; + parameters.local_window_size = local_window_size_; int sequence_length = parameters.sequence_length; TensorShapeVector output_shape(3); @@ -139,6 +137,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { bool use_memory_efficient_attention = !use_flash_attention && !disable_memory_efficient_attention_ && + local_window_size_ == -1 && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -222,6 +221,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.k = reinterpret_cast(k_buffer.get()); data.v = reinterpret_cast(v_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v 
= reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index aade0436dc141..54a8127e29e7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -22,8 +22,7 @@ class GroupQueryAttention final : public CudaKernel { protected: int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention - // bool left_padding_; // shifts last token to end of buffer - bool is_unidirectional_; // causal + int local_window_size_; bool is_past_bsnh_; float scale_; bool disable_flash_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 2d158155eeba9..b22ccb68c1e7b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -468,55 +468,6 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i return CUDA_CALL(cudaGetLastError()); } -// // Kernel to append new kv to kv buffer in place -// template -// __global__ void LeftPadLast(const int max_seqlen, -// T* kv_buff, -// const int* seqlens_k) { // refers to kv buff; otherwise bnsh -// const int h = threadIdx.x; -// const int n = blockIdx.x; -// const int b = blockIdx.y; - -// const int num_heads = gridDim.x; -// const int H = blockDim.x; - -// const int present_batch_stride = max_seqlen * num_heads * H; -// const int present_row_stride = num_heads * H; -// const int present_head_stride = H; - -// // kv_buff: BTNH or BNTH with buffered memory for new -// // new_kv: BLNH - -// const int s = seqlens_k[b]; - -// const int in_offset = b * present_batch_stride + s * present_row_stride + n * present_head_stride + h; -// const int out_offset = b * present_batch_stride + (max_seqlen - 1) * present_row_stride + n * present_head_stride + h; -// kv_buff[out_offset] = kv_buff[in_offset]; -// } - -// // Concat new to kv buffer in place -// template -// Status LaunchLeftPadLast(contrib::GroupQueryAttentionParameters& parameters, -// GroupQueryAttentionData& data, -// cudaStream_t stream, -// const int max_threads_per_block) { -// const int batch_size = parameters.batch_size; -// const int sequence_length = parameters.sequence_length; -// const int num_heads = parameters.num_heads; -// const int head_size = parameters.head_size; - -// // Indicates past sequence_length of each sequence -// const int* seqlens_k = reinterpret_cast(data.seqlens_k); - -// const int H = head_size / 4; -// const dim3 grid(num_heads, batch_size, 1); -// const dim3 block(H, 1, 1); -// LeftPadLast<<>>(sequence_length, -// reinterpret_cast(data.output), -// seqlens_k); -// return CUDA_CALL(cudaGetLastError()); -// } - ////////// Launch Kernels #if USE_FLASH_ATTENTION @@ -541,7 +492,7 @@ Status FlashAttention( void* key = reinterpret_cast(const_cast(data.key)); void* value = reinterpret_cast(const_cast(data.value)); - bool is_causal = parameters.is_unidirectional; + bool is_causal = true; // Note: seqlens_k is past sequence length for flash if (parameters.is_prompt) { @@ -579,7 +530,7 @@ Status FlashAttention( seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, 
present_sequence_length, kv_sequence_length, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); + reinterpret_cast(data.out_accum), parameters.local_window_size)); } else { // Not share buffer case // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient @@ -611,13 +562,9 @@ Status FlashAttention( seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, present_sequence_length, 0, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); + reinterpret_cast(data.out_accum), parameters.local_window_size)); } - // if (parameters.left_padding && parameters.is_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -704,9 +651,11 @@ Status EfficientAttention( p.max_sequence_length = present_sequence_length; p.qk_head_size = head_size; p.v_head_size = head_size; - p.causal = parameters.is_unidirectional; + p.causal = true; p.scale = scale; p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; p.query = query; p.key = key; p.value = value; @@ -721,10 +670,6 @@ Status EfficientAttention( p.has_custom_right_padding = true; run_memory_efficient_attention(p); - // if (parameters.left_padding && parameters.is_prompt) { - // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); - // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index b4b5dac1fbe19..2d12e975d88d7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -74,7 +74,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { parameters.max_sequence_length, parameters.position_ids_format, interleaved, - device_prop.maxThreadsPerBlock); + device_prop.maxThreadsPerBlock, + parameters.transposed); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index c54e72dcfce13..e1b83bd8caf54 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -27,7 +27,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int num_heads, const int head_size, const int position_ids_format, - const bool interleaved) { + const bool interleaved, + const int batch_stride, + const int seq_stride, + const int head_stride) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently @@ -37,11 +40,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int i = threadIdx.x; - const int block_offset = b * sequence_length * num_heads + s * num_heads + n; - const int data_offset = block_offset * head_size; + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - const T* input_data = input + 
data_offset; - T* output_data = output + data_offset; + const T* input_data = input + block_offset; + T* output_data = output + block_offset; // Cache is (M, H/2) const int half_head_size = head_size / 2; @@ -83,7 +85,8 @@ Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block) { + const int max_threads_per_block, + const bool transposed) { constexpr int smem_size = 0; const dim3 grid(num_heads, sequence_length, batch_size); @@ -94,10 +97,22 @@ Status LaunchRotaryEmbeddingKernel( // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. + // Default input tensor shape is [batch, seq, hidden_size] + int head_stride = head_size; + int seq_stride = num_heads * head_stride; + int batch_stride = sequence_length * seq_stride; + if (transposed) { + // When transposed, input tensor shape is [batch, num_heads, seq, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } + assert(head_size <= max_threads_per_block); RotaryEmbeddingBSNH<<>>( output, input, cos_cache, sin_cache, position_ids, - sequence_length, num_heads, head_size, position_ids_format, interleaved + sequence_length, num_heads, head_size, position_ids_format, interleaved, + batch_stride, seq_stride, head_stride ); return CUDA_CALL(cudaGetLastError()); @@ -117,7 +132,8 @@ template Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); template Status LaunchRotaryEmbeddingKernel( cudaStream_t stream, @@ -133,7 +149,8 @@ template Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index 29ff48a8ad0fb..ee1ccc43dcbff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -24,7 +24,8 @@ Status LaunchRotaryEmbeddingKernel( const int max_sequence_length, const int position_ids_format, const bool interleaved, - const int max_threads_per_block); + const int max_threads_per_block, + const bool transposed); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 7172a28316f16..108eea1a73fe9 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -121,6 +121,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); @@ -313,6 +314,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu index e58723f0b31e1..2f74dd41f0759 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -35,6 +35,8 @@ template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, c template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream); +template Status SetBnbQuantMap(int quant_type, BFloat16* quant_map_buffer, cudaStream_t stream); + template __global__ void kDequantizeBlockwise( const T* quant_map, @@ -62,22 +64,15 @@ __global__ void kDequantizeBlockwise( valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2; - local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]); + local_abs_max = absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]; __syncthreads(); LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max; - vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max; - #else - // half multiplication not supported - vals[j * 2] = static_cast(static_cast(quant_map[qvals[j] >> 4]) * static_cast(local_abs_max)); - vals[j * 2 + 1] = - static_cast(static_cast(quant_map[qvals[j] & 0x0F]) * static_cast(local_abs_max)); - #endif + vals[j * 2] = ScalarMul(quant_map[qvals[j] >> 4], local_abs_max); + vals[j * 2 + 1] = ScalarMul(quant_map[qvals[j] & 0x0F], local_abs_max); } __syncthreads(); @@ -86,7 +81,7 @@ __global__ void kDequantizeBlockwise( } template -Status DequantizeBnb4( +void CallkDequantizeBlockwise( const T* quant_map, T* output, const uint8_t* quant_data, @@ -102,6 +97,18 @@ Status DequantizeBnb4( absmax, block_size / 2, numel); +} + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); return Status::OK(); } @@ -119,11 +126,36 @@ template Status DequantizeBnb4( const half* quant_map, half* output, const uint8_t* quant_data, - const half *absmax, + const half* absmax, int block_size, int numel, cudaStream_t stream); +template <> +Status DequantizeBnb4( + const BFloat16* quant_map, + BFloat16* output, + const uint8_t* quant_data, + const BFloat16* absmax, + int block_size, + int numel, + cudaStream_t stream) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + CallkDequantizeBlockwise( + reinterpret_cast(quant_map), + reinterpret_cast(output), + quant_data, + reinterpret_cast(absmax), + block_size, + numel, + stream); + #else + CallkDequantizeBlockwise(quant_map, output, quant_data, absmax, block_size, numel, stream); + #endif + + return 
Status::OK(); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh index 4aef3ab699f9c..a0d38c9853cd6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -11,6 +11,38 @@ namespace cuda { template Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream); +// templated scalar multiply function +template +__device__ inline T ScalarMul(T a, T b); + +template <> +__device__ inline float ScalarMul(float a, float b) { + return a * b; +} + +template <> +__device__ inline half ScalarMul(half a, half b) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return a * b; + #else + // half multiplication not supported + return static_cast(static_cast(a) * static_cast(b)); + #endif +} + +template <> +__device__ inline BFloat16 ScalarMul(BFloat16 a, BFloat16 b) { + return a * b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// will use the native bfloat16 multiply instruction on sm_80+ +template <> +__device__ inline nv_bfloat16 ScalarMul(nv_bfloat16 a, nv_bfloat16 b) { + return a * b; +} +#endif + template Status DequantizeBnb4( const T* quant_map, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index ecf332715d470..bbcb7de99781f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -145,6 +145,17 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("T2", DataTypeImpl::GetTensorType()), MatMulBnb4); +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu index 1d9aa75ff3701..098e3618beddd 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -6,12 +6,44 @@ #include #include #include +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" #include "matmul_bnb4.cuh" namespace onnxruntime { namespace contrib { namespace cuda { +template +__device__ inline float ScalarMulFloatOut(T a, T b); + +template <> +__device__ inline float ScalarMulFloatOut(float a, float b) { + return a * b; +} + +template <> +__device__ inline float ScalarMulFloatOut(half a, half b) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + return static_cast(a * b); + #else + // half multiplication not supported + return static_cast(a) * static_cast(b); + #endif +} + +template <> +__device__ inline float ScalarMulFloatOut(BFloat16 a, BFloat16 b) { + return a * b; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// will use the native bfloat16 multiply instruction on sm_80+ +template <> +__device__ inline float ScalarMulFloatOut(nv_bfloat16 a, nv_bfloat16 b) { + return static_cast(a * b); +} +#endif + #define num_values_4bit 32 template __global__ void kgemm_4bit_inference_naive( @@ -55,7 +87,7 @@ __global__ void kgemm_4bit_inference_naive( int 
inner_idx_halved = inner_idx / 2; int offset_B = ldb * row_B; int absidx = ((2 * offset_B) + inner_idx) / block_size; - local_absmax = __ldg(&(absmax[absidx])); + local_absmax = absmax[absidx]; if (row_B < N) { if ((inner_idx_halved + num_values_8bit) < (K / 2)) { @@ -78,18 +110,8 @@ __global__ void kgemm_4bit_inference_naive( for (int i = 0; i < 4; i++) { #pragma unroll for (int k = 0; k < num_values_8bit / 4; k++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; - local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; - #else - // half multiplication not supported - local_B[k * 2] = - static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) * - static_cast(local_absmax)); - local_B[k * 2 + 1] = - static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) * - static_cast(local_absmax)); - #endif + local_B[k * 2] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4], local_absmax); + local_B[k * 2 + 1] = ScalarMul(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F], local_absmax); } if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { @@ -116,12 +138,7 @@ __global__ void kgemm_4bit_inference_naive( // accumulate in float; small performance hit for Ampere, but lower error for outputs #pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) { - #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 - local_C += static_cast(local_A[k] * local_B[k]); - #else - // half multiplication not supported - local_C += static_cast(local_A[k]) * static_cast(local_B[k]); - #endif + local_C += ScalarMulFloatOut(local_A[k], local_B[k]); } } } @@ -131,8 +148,19 @@ __global__ void kgemm_4bit_inference_naive( if (row_B < N && warp_lane == 0) out[row_B] = T(local_C); } +bool CheckDims(int m, int k, int block_size) { + if (k % block_size != 0 || m > 1) { + return false; + } + // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] + if (block_size % 32 != 0 || block_size > 4096) { + return false; + } + return true; +} + template -bool TryMatMulBnb4( +void Callkgemm_4bit_inference_naive( const T* quant_map, T* output, const T* a_data, @@ -143,22 +171,34 @@ bool TryMatMulBnb4( int k, int block_size, cudaStream_t stream) { - if (k % block_size != 0 || m > 1) { - return false; - } - // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] - if (block_size % 32 != 0 || block_size > 4096) { - return false; - } - int lda = k; int ldb = (k + 1) / 2; int ldc = n; int num_blocks = (n + 3) / 4; - constexpr int bits = std::is_same_v ? 16 : 32; + constexpr int bits = std::is_same_v ? 
32 : 16; kgemm_4bit_inference_naive<<>>( m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size); +} + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (!CheckDims(m, k, block_size)) { + return false; + } + + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); return true; } @@ -187,6 +227,42 @@ template bool TryMatMulBnb4( int block_size, cudaStream_t stream); +template <> +bool TryMatMulBnb4( + const BFloat16* quant_map, + BFloat16* output, + const BFloat16* a_data, + const uint8_t* b_data_quant, + const BFloat16* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (!CheckDims(m, k, block_size)) { + return false; + } + + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + Callkgemm_4bit_inference_naive( + reinterpret_cast(quant_map), + reinterpret_cast(output), + reinterpret_cast(a_data), + b_data_quant, + reinterpret_cast(absmax), + m, + n, + k, + block_size, + stream); + #else + Callkgemm_4bit_inference_naive( + quant_map, output, a_data, b_data_quant, absmax, m, n, k, block_size, stream); + #endif + + return true; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/attention.cc b/onnxruntime/contrib_ops/js/bert/attention.cc new file mode 100644 index 0000000000000..723ff00aa815e --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + Attention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + Attention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/attention.h b/onnxruntime/contrib_ops/js/bert/attention.h new file mode 100644 index 0000000000000..0fa823befa9b2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class Attention : public JsKernel, AttentionBase { + public: + explicit Attention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + std::vector qkv_sizes(qkv_hidden_sizes_.size()); + if (qkv_hidden_sizes_.size() > 0) { + std::transform(qkv_hidden_sizes_.begin(), qkv_hidden_sizes_.end(), qkv_sizes.begin(), + [](int64_t sz) { return gsl::narrow_cast(sz); }); + } + + JSEP_INIT_KERNEL_ATTRIBUTE(Attention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + "qkvHiddenSizes" : $6 ? 
(Array.from(HEAP32.subarray(Number($7), Number($7) + $6))) : [], + "pastPresentShareBuffer" : !!$8, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_), + static_cast(qkv_hidden_sizes_.size()), + reinterpret_cast((qkv_sizes.size() > 0) ? qkv_sizes.data() : nullptr) >> 2, + static_cast(past_present_share_buffer_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc new file mode 100644 index 0000000000000..c43f8b7f18465 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "multi_head_attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + MultiHeadAttention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.h b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h new file mode 100644 index 0000000000000..6c63a2ffed4b2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class MultiHeadAttention : public JsKernel, AttentionBase { + public: + explicit MultiHeadAttention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + JSEP_INIT_KERNEL_ATTRIBUTE(MultiHeadAttention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 24d327576ecd9..498a9f5679eb5 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -7,7 +7,9 @@ namespace onnxruntime { namespace contrib { namespace js { +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); @@ -21,7 +23,9 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const 
BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfoName(); - if (input_map.find(name) == end_map) { + if (input_map.find(name) == input_map.cend()) { // dummy entry for an input that we didn't find a use of in the graph. log it in case that's a bug. // utils::CopyOneInputAcrossDevices will use the input OrtValue as is given we don't believe it's used anywhere. LOGS(session_state.Logger(), INFO) << (graph.IsSubgraph() ? "Subgraph" : "Graph") << " input with name " diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index dcde2ddeb8270..b97fb0d2899fc 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -991,7 +991,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( constexpr const char* GroupQueryAttention_ver1_doc = R"DOC( Group Query Self/Cross Attention. -Supports different number of heads for q and kv. +Supports different number of heads for q and kv. Only supports causal or local attention. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1004,10 +1004,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) - // .Attr("left_padding_last_token", - // "Copy last token to last index of buffer. Default is 0; 1 when true.", - // AttributeProto::INT, - // OPTIONAL_VALUE) + .Attr("local_window_size", + "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", + AttributeProto::INT, + static_cast(-1)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size)", @@ -1144,7 +1144,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OPTIONAL_VALUE) .Input(0, "input", - "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", "T") .Input(1, "position_ids", @@ -1160,7 +1160,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T") .Output(0, "output", - "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "tensor with same shape as input.", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index db0b13b0e1d27..4c0d78f0ee297 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3431,7 +3431,7 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 .Input(1, "B", "1-dimensional quantized data for weight", "T2") .Input(2, "absmax", "quantization constants", "T1") .Output(0, "Y", "tensor. The output tensor has the same rank as the input. 
", "T1") - .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float/half_float/brain_float tensors.") .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc index 094ea1e24dd92..9c98ed6d3e114 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc @@ -338,8 +338,8 @@ std::optional IsSupportedGather(Graph& graph, Node& node, auto axis = static_cast(node.GetAttributes().at("axis").i()); axis = axis < 0 ? axis + data_rank : axis; size_t dim_size = static_cast(indices_shape->dim_size()); - bool is_single_value_1d_tensor = dim_size != 0 && (dim_size == 1 && utils::HasDimValue(indices_shape->dim(0)) && - indices_shape->dim(0).dim_value() == 1); + bool is_single_value_1d_tensor = dim_size == 1 && utils::HasDimValue(indices_shape->dim(0)) && + indices_shape->dim(0).dim_value() == 1; if (dim_size != 0 && !is_single_value_1d_tensor) { if (dim_size == 1 && utils::HasDimValue(data_shape->dim(axis)) && data_shape->dim(axis).dim_value() > indices_shape->dim(0).dim_value()) { diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc index 716b027068ba1..23f7c45fba4ba 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.cc @@ -3,6 +3,7 @@ #ifdef ENABLE_TRAINING +#include #include "core/optimizer/utils.h" #include "core/optimizer/compute_optimizer/upstream_reshape_actors.h" @@ -282,6 +283,23 @@ bool LayerNormalizationReshapeActor::PreCheck( return propagate_input_indices.size() > 0; } +bool LayerNormalizationReshapeActor::PostProcess( + Graph& /* graph */, Node& current_node, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + // When Reshape(from 3D to 2D, with the first two dimensions be merged) upstream a LayerNormalization, + // The axis attribute of LayerNormalization should be decreased by 1 if it is greater than 1. 
+ if (axis > 1) { + auto new_axis = axis - 1; + auto& attributes = current_node.GetMutableAttributes(); + attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis)); + } + return true; +} + template class SimplePointwiseReshapeActor; template class SimplePointwiseReshapeActor; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h index 05bcbabe9ba4c..de50a56fd8781 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape_actors.h @@ -111,13 +111,11 @@ class UpStreamReshapeOperatorActorBase : public UpStreamOperatorActorBase { * So far, we don't have requirements to override PostProcess function. */ - bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, - const logging::Logger& /* logger */, - std::vector& /* propagate_input_indices */, - const std::unordered_map>& /* all_input_cmp_rets */, - const std::unordered_map& /* new_reshape_infos */) { - return true; - } + virtual bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) = 0; }; // The inputs are broad-cast-able. The outputs should have the same shape (fully broadcasted shape) @@ -133,6 +131,14 @@ class SimplePointwiseReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override { + return true; + } }; class MatMulReshapeActor : public UpStreamReshapeOperatorActorBase { @@ -145,6 +151,14 @@ class MatMulReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override { + return true; + } }; class LayerNormalizationReshapeActor : public UpStreamReshapeOperatorActorBase { @@ -157,6 +171,12 @@ class LayerNormalizationReshapeActor : public UpStreamReshapeOperatorActorBase { std::vector& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& current_node, const ReshapeInfo& /* info_without_node */, + const logging::Logger& /* logger */, + std::vector& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_reshape_infos */) override; }; /** diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 
5015e48fdb7b8..3880288bdba2e 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -443,7 +443,6 @@ bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& gr } int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - int32_t dt_scale = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_bias = 0; bool has_bias = false; // bias is optional for LayerNorm @@ -453,9 +452,9 @@ bool InstanceAndLayerNormalizationNodeGroupSelector::Check(const GraphViewer& gr } int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - // Input, output, and scale need to be the same type. The bias is int32. + // Input, output, need to be the same type. The bias is int32. + // Scale can be different with input for a16w8 case return (dt_input == dt_output) && - (dt_input == dt_scale) && (has_bias ? dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32 : true); } diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 60e0b1c061a43..4a6743e9e5c52 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -169,6 +170,60 @@ Status CreateInputFeatureProvider(const std::unordered_map mlmultiarray_buffer_size) { + if (mlmultiarray_buffer == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "mlmultiarray_buffer has no data"); + } + + const size_t num_elements = array_info.count; + const auto onnx_data_type = tensor_info->data_type; + switch (onnx_data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + const auto output_data_byte_size = num_elements * sizeof(float); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, + "CoreML output buffer size and expected output size differ"); + memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + const auto output_data_byte_size = num_elements * sizeof(int32_t); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, + "CoreML output buffer size and expected output size differ"); + memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + break; + } + // For this case, since Coreml Spec only uses int32 for model output while onnx provides + // int64 for model output data type. 
We are doing a type casting (int32 -> int64) here + // when copying the model to ORT + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + ORT_RETURN_IF_NOT(array_info.dataType == MLMultiArrayDataTypeInt32, + "CoreML output data type is not MLMultiArrayDataTypeInt32"); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == num_elements * sizeof(int32_t), + "CoreML output buffer size and expected output size differ"); + const auto model_output_span = gsl::span{static_cast(mlmultiarray_buffer), num_elements}; + const auto output_span = gsl::span{static_cast(tensor_buffer), num_elements}; + std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), + [](int32_t v) { return static_cast(v); }); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Output data type is not supported, actual type: ", onnx_data_type); + } + return Status::OK(); +} } // namespace NS_ASSUME_NONNULL_BEGIN @@ -298,9 +353,9 @@ - (Status)predict:(const std::unordered_map&)inputs return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); } - auto* data = [output_value multiArrayValue]; + MLMultiArray* data = [output_value multiArrayValue]; - const auto coreml_static_output_shape = [&]() { + const auto coreml_static_output_shape = [data]() { InlinedVector result; result.reserve(data.shape.count); for (NSNumber* dim in data.shape) { @@ -324,41 +379,21 @@ - (Status)predict:(const std::unordered_map&)inputs ") do not match"); } - const void* model_output_buffer = data.dataPointer; - - if (model_output_buffer == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model_output_buffer has no data for ", output_name); - } - - const auto onnx_data_type = output_tensor_info.data_type; - switch (onnx_data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - const auto output_data_byte_size = num_elements * sizeof(float); - memcpy(output_buffer, model_output_buffer, output_data_byte_size); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - const auto output_data_byte_size = num_elements * sizeof(int32_t); - memcpy(output_buffer, model_output_buffer, output_data_byte_size); - break; - } - // For this case, since Coreml Spec only uses int32 for model output while onnx provides - // int64 for model output data type. 
We are doing a type casting (int32 -> int64) here - // when copying the model to ORT - case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - ORT_RETURN_IF_NOT(data.dataType == MLMultiArrayDataTypeInt32, - "CoreML output data type is not MLMultiArrayDataTypeInt32"); - - const auto model_output_span = gsl::span{static_cast(model_output_buffer), num_elements}; - const auto output_span = gsl::span{static_cast(output_buffer), num_elements}; - std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), - [](int32_t v) { return static_cast(v); }); - break; - } - default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Output data type is not supported, actual type: ", onnx_data_type); + ORT_RETURN_IF_NOT(IsArrayContiguous(data), + "Non-contiguous output MLMultiArray is not currently supported"); + __block Status copy_status; + const auto* tensor_info = &output_tensor_info; + // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions + if (@available(macOS 12.3, iOS 15.4, *)) { + [data getBytesWithHandler:^(const void* bytes, NSInteger size) { + copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, tensor_info, size); + }]; + } else { + // disable size check as old API does not return buffer length + copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, tensor_info, std::nullopt); } + if (!copy_status.IsOK()) + return copy_status; } } } diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index ce834e371fdef..3c83394fb0bf4 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -688,21 +688,23 @@ FastReduceKind OptimizeShapeForFastReduce(gsl::span input_shape, return FastReduceKind::kNone; } -void ValidateCommonFastReduce(const Tensor* axes_tensor) { - ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, - "An axes tensor must be a vector tensor."); -} - // template bool CommonFastReduceCopy(OpKernelContext* ctx, TensorShapeVector& input_axes, bool noop_with_empty_axes) { if (ctx->InputCount() == 2) { // second input holds the axes. 
+ // the argument is optional const Tensor* axes_tensor = ctx->Input(1); - ValidateCommonFastReduce(axes_tensor); - auto nDims = static_cast(axes_tensor->Shape()[0]); - const auto* data = axes_tensor->Data(); - input_axes.insert(input_axes.begin(), data, data + nDims); + + if (axes_tensor != nullptr) { + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a vector tensor."); + + const auto data_span = axes_tensor->DataAsSpan(); + input_axes.assign(data_span.begin(), data_span.end()); + } else { + input_axes.clear(); + } + if (input_axes.empty() && noop_with_empty_axes) { const Tensor* input = ctx->Input(0); auto* output = ctx->Output(0, input->Shape()); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc index 4ae59951c5e98..fdc5317419c5b 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/simple_op_builder.cc @@ -22,6 +22,11 @@ class SimpleOpBuilder : public BaseOpBuilder { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SimpleOpBuilder); protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -48,6 +53,90 @@ class SimpleOpBuilder : public BaseOpBuilder { static constexpr std::array gridsample_supported_padding_modes = {"zeros", "border", "reflection"}; }; +// Move to qnn_utils if it's re-usable +Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper, + const std::string& convert_input_name, + const std::string& convert_output_name, + Qnn_DataType_t input_qnn_data_type, + Qnn_DataType_t output_qnn_data_type, + int32_t input_offset, + float input_scale, + const std::vector& output_shape, + bool do_op_validation) { + // Assume input is already handled. 
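+ // The code below recovers the real-valued range implied by the input's scale/offset, derives matching + // quantization parameters for the Convert op's output data type, then registers the output tensor and a QNN "Convert" node + // that re-quantizes the activation into that range.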
+ float qmin = 0.0f; + float qmax = 255.0f; + ORT_RETURN_IF_ERROR(qnn::utils::GetQminQmax(input_qnn_data_type, qmin, qmax)); + double value_min = qnn::utils::Dequantize(input_offset, input_scale, qmin); + double value_max = qnn::utils::Dequantize(input_offset, input_scale, qmax); + + Qnn_QuantizeParams_t convert_output_quant_param = QNN_QUANTIZE_PARAMS_INIT; + convert_output_quant_param.encodingDefinition = QNN_DEFINITION_DEFINED; + convert_output_quant_param.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + ORT_RETURN_IF_ERROR(qnn::utils::GetQuantParams(static_cast(value_min), + static_cast(value_max), + output_qnn_data_type, + convert_output_quant_param.scaleOffsetEncoding.scale, + convert_output_quant_param.scaleOffsetEncoding.offset)); + + std::vector output_shape_copy = output_shape; + QnnTensorWrapper convert_output_tensorwrapper(convert_output_name, + QNN_TENSOR_TYPE_NATIVE, + output_qnn_data_type, + convert_output_quant_param, + std::move(output_shape_copy)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(convert_output_tensorwrapper)), "Failed to add tensor."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(convert_output_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + "Convert", + {convert_input_name}, + {convert_output_name}, + {}, + do_op_validation), + "Failed to add node."); + return Status::OK(); +} + +Status SimpleOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + const std::string& op_type = node_unit.OpType(); + ORT_RETURN_IF_ERROR(BaseOpBuilder::ProcessInputs(qnn_model_wrapper, node_unit, logger, input_names, do_op_validation)); + + if (op_type == "MatMul") { + const auto& inputs = node_unit.Inputs(); + TensorInfo input0_info = {}; + TensorInfo input1_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input0_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input1_info)); + // Need to insert Convert op if both inputs are dynamic inputs and are ufixed_16 + if (!input0_info.is_initializer && !input1_info.is_initializer && + input0_info.qnn_data_type == input1_info.qnn_data_type && + input0_info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + // insert Convert op after input1 + std::string convert_input_name = input_names.back(); + input_names.pop_back(); + const std::string& matmul_output_name = node_unit.Outputs()[0].node_arg.Name(); + std::string convert_output_name = convert_input_name + "_convert_" + matmul_output_name; + ORT_RETURN_IF_ERROR(InsertConvertOp(qnn_model_wrapper, + convert_input_name, + convert_output_name, + input1_info.qnn_data_type, + QNN_DATATYPE_UFIXED_POINT_8, + input1_info.quant_param.scaleOffsetEncoding.offset, + input1_info.quant_param.scaleOffsetEncoding.scale, + input1_info.shape, + do_op_validation)); + input_names.push_back(convert_output_name); + } + } + + return Status::OK(); +} + Status SimpleOpBuilder::ExplicitOpCheck(const NodeUnit& node_unit) const { const std::string& op_type = node_unit.OpType(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 020af451cdcd5..79f84864a5788 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1194,6 +1194,11 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() { } } + if 
(external_stream_) { + ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_))); + ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_))); + } + if (!external_stream_ && stream_) { ORT_IGNORE_RETURN_VALUE(CUDA_CALL(cudaStreamDestroy(stream_))); } @@ -1272,6 +1277,20 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { return Status::OK(); } +// Get the pointer to the IBuilder instance. +// Note: This function is not thread safe. Calls to this function from different threads must be serialized +// even though it doesn't make sense to have multiple threads initializing the same inference session. +nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { + if (!builder_) { + TensorrtLogger& trt_logger = GetTensorrtLogger(); + { + auto lock = GetApiLock(); + builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + } + } + return builder_.get(); +} + void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { if (info_.custom_op_domain_list.empty()) { common::Status status = CreateTensorRTCustomOpDomainList(info_); @@ -1633,7 +1652,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect // Get supported node list recursively SubGraphCollection_t parser_nodes_list; TensorrtLogger& trt_logger = GetTensorrtLogger(); - auto trt_builder = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + auto trt_builder = GetBuilder(); const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(explicitBatch)); @@ -1810,6 +1829,10 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, if (sub_graphs.size() != 0) { bool all_subgraphs_are_supported = true; for (auto sub_graph : sub_graphs) { + // TRT EP should consider the empty subgraph is fully supported by TRT. + if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { + continue; + } if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { all_subgraphs_are_supported = false; break; @@ -1877,27 +1900,33 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, auto sub_graphs = graph.ParentNode()->GetSubgraphs(); for (auto sub_graph : sub_graphs) { if (sub_graph.get() != &graph.GetGraph()) { - auto sub_graph_veiwer = sub_graph->CreateGraphViewer(); - const int number_of_ort_subgraph_nodes = sub_graph_veiwer->NumberOfNodes(); + auto sub_graph_viewer = sub_graph->CreateGraphViewer(); + const int number_of_ort_subgraph_nodes = sub_graph_viewer->NumberOfNodes(); std::vector subgraph_nodes_vector(number_of_ort_subgraph_nodes); std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; bool subgraph_early_termination = false; - // Another subgraph of "If" control flow has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. - if (AllNodesAssignedToSpecificEP(*sub_graph_veiwer, kTensorrtExecutionProvider)) { + // Another subgraph of "If" control flow op has no nodes. + // In this case, TRT EP should consider this empty subgraph is fully supported by TRT. + if (sub_graph_viewer->NumberOfNodes() == 0) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. 
+ else if (AllNodesAssignedToSpecificEP(*sub_graph_viewer, kTensorrtExecutionProvider)) { all_subgraphs_are_supported = true; break; } // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP. // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs) - else if (!AllNodesAssignedToSpecificEP(*sub_graph_veiwer, "")) { + else if (!AllNodesAssignedToSpecificEP(*sub_graph_viewer, "")) { all_subgraphs_are_supported = false; break; } // Another subgraph of "If" control flow has not yet been parsed by GetCapability. - subgraph_supported_nodes_vector = GetSupportedList(parser_subgraph_nodes_vector, 0, max_partition_iterations_, *sub_graph_veiwer, &subgraph_early_termination); + subgraph_supported_nodes_vector = GetSupportedList(parser_subgraph_nodes_vector, 0, max_partition_iterations_, *sub_graph_viewer, &subgraph_early_termination); all_subgraphs_are_supported = IsSubGraphFullySupported(subgraph_supported_nodes_vector, number_of_ort_subgraph_nodes); break; } @@ -1985,7 +2014,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(nvinfer1::createInferBuilder(trt_logger)); + auto trt_builder = GetBuilder(); const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(explicitBatch)); auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); @@ -2438,7 +2467,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, - &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], + *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), + &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, @@ -2490,7 +2518,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsync_stream_after_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; - auto trt_builder = trt_state->builder->get(); + auto trt_builder = trt_state->builder; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index cda08715ea009..a945d219088aa 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -105,10 +105,10 @@ struct TensorrtFuncState { DestroyFunc test_release_func = nullptr; AllocatorHandle allocator = nullptr; std::string fused_node_name; + nvinfer1::IBuilder* builder; tensorrt_ptr::unique_pointer* parser = nullptr; std::unique_ptr* engine = nullptr; std::unique_ptr* context = nullptr; - std::unique_ptr* builder = nullptr; std::unique_ptr* network = nullptr; std::vector> input_info; 
std::vector> output_info; @@ -245,6 +245,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; mutable std::unordered_map> subgraph_context_map_; + mutable std::unique_ptr builder_; + // Following maps that hold TRT objects will be accessible by different threads if ORT is using multithreading. // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client. // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading @@ -456,5 +458,11 @@ class TensorrtExecutionProvider : public IExecutionProvider { void CaptureBegin(); void CaptureEnd(); void IncrementRegularRunCountBeforeGraphCapture(); + + /** + * Get the pointer to the IBuilder instance. + * This function only creates the instance at the first time it's being called." + */ + nvinfer1::IBuilder* GetBuilder() const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 46c456556e016..8ae16f0dd21fc 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -156,6 +156,7 @@ static const InlinedHashMap op_map = { {"GlobalMaxPool", "maxPool2d"}, {"GlobalLpPool", "l2Pool2d"}, {"Greater", "greater"}, + {"GreaterOrEqual", "greaterOrEqual"}, {"GroupNormalization", "meanVarianceNormalization"}, {"HardSigmoid", "hardSigmoid"}, {"HardSwish", "hardSwish"}, @@ -164,6 +165,7 @@ static const InlinedHashMap op_map = { {"LayerNormalization", "meanVarianceNormalization"}, {"LeakyRelu", "leakyRelu"}, {"Less", "lesser"}, + {"LessOrEqual", "lesserOrEqual"}, {"Log", "log"}, {"LpPool", "l2Pool2d"}, {"MatMul", "matmul"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 4cb49d8f8cd3a..c8f58fa98635f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -35,8 +35,12 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons output = model_builder.GetBuilder().call("equal", input0, input1); } else if (op_type == "Greater") { output = model_builder.GetBuilder().call("greater", input0, input1); + } else if (op_type == "GreaterOrEqual") { + output = model_builder.GetBuilder().call("greaterOrEqual", input0, input1); } else if (op_type == "Less") { output = model_builder.GetBuilder().call("lesser", input0, input1); + } else if (op_type == "LessOrEqual") { + output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -54,7 +58,9 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& { "Equal", "Greater", + "GreaterOrEqual", "Less", + "LessOrEqual", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 8778bb2414108..e48cf35012652 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -114,6 +114,22 @@ bool 
SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, if (!GetShape(*input_defs[0], input_shape, logger)) { return false; } + + if (input_defs.size() < 3) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data, starts, ends) but got " + << input_defs.size(); + return false; + } + + // Inputs: starts, ends, axes, and steps must be constant initializers if present. + for (size_t i = 1; i < input_defs.size(); i++) { + if (!Contains(initializers, input_defs[i]->Name())) { + LOGS(logger, VERBOSE) << "Input [" << input_defs[i]->Name() << "] of " << op_type + << " [" << name << "] must be known as initializer"; + return false; + } + } + if (input_defs.size() == 5) { // Check steps. const auto& steps_tensor = *initializers.at(input_defs[4]->Name()); std::vector unpacked_tensor; @@ -140,18 +156,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } - if (input_defs.size() < 3) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 3 inputs (data starts and ends) but got " - << input_defs.size(); - return false; - } - - const auto& starts_name = input_defs[1]->Name(); - const auto& ends_name = input_defs[2]->Name(); - if (!Contains(initializers, starts_name) || !Contains(initializers, ends_name)) { - LOGS(logger, VERBOSE) << op_type << " [" << name << "] need starts and ends as initializer."; - return false; - } return true; } diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 65dc8ddbeaf90..463317a4dafda 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -99,7 +99,9 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { { // Logical CreateLogicalOpBuilder("Equal", op_registrations); CreateLogicalOpBuilder("Greater", op_registrations); + CreateLogicalOpBuilder("GreaterOrEqual", op_registrations); CreateLogicalOpBuilder("Less", op_registrations); + CreateLogicalOpBuilder("LessOrEqual", op_registrations); } { // Max/Min diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index ccedc71b9119a..f02d180ab104f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2025,9 +2025,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spansecond.tensor_shape.has_value()) { + const auto& opt_shape = iter->second.tensor_shape; + if (opt_shape.has_value() && !opt_shape->GetDims().empty()) { ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, input_output_tensor.Shape(), - *iter->second.tensor_shape, input_output_moniker)); + *opt_shape, input_output_moniker)); } } else if (input_output_ml_value.IsSparseTensor()) { #if !defined(DISABLE_SPARSE_TENSORS) @@ -2038,9 +2039,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spansecond.tensor_shape.has_value()) { + const auto& opt_shape = iter->second.tensor_shape; + if (opt_shape.has_value() && !opt_shape->GetDims().empty()) { ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(), - *iter->second.tensor_shape, input_output_moniker)); + *opt_shape, input_output_moniker)); } } else if (is_sparse_initializer(name) && expected_type->IsTensorType()) { @@ -2049,9 +2051,10 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spansecond.tensor_shape.has_value()) { + const auto& opt_shape = 
iter->second.tensor_shape; + if (opt_shape.has_value() && !opt_shape->GetDims().empty()) { ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(name, sparse_tensor.DenseShape(), - *iter->second.tensor_shape, input_output_moniker)); + *opt_shape, input_output_moniker)); } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, input_output_moniker, " with name: '", name, @@ -2061,7 +2064,6 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::spanIsTensorSequenceType() #if !defined(DISABLE_OPTIONAL_TYPE) diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index c1b241aa1a5ec..d11cb91d98b0c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -657,7 +657,6 @@ def create_multihead_attention_node( return None graph_input_names = set([node.name for node in self.model.graph().input]) - graph_output_names = set([node.name for node in self.model.graph().output]) mha_node_name = self.model.create_node_name("Attention") # Add initial Q/K/V inputs for MHA @@ -693,12 +692,15 @@ def create_multihead_attention_node( mha_inputs.append("") # Add optional inputs for MHA - if past_k and past_v and past_k in graph_input_names and past_v in graph_input_names: + + if past_k and past_v: mha_inputs.extend([key_padding_mask, add_qk, past_k, past_v]) + elif key_padding_mask or add_qk: + mha_inputs.extend([key_padding_mask, add_qk]) # Add outputs for MHA mha_outputs = [output] - if present_k and present_v and present_k in graph_output_names and present_v in graph_output_names: + if present_k and present_v: mha_outputs.extend([present_k, present_v]) mha_node = helper.make_node( diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py new file mode 100644 index 0000000000000..6bc681c57444e --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -0,0 +1,143 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +from fusion_attention import AttentionMask, FusionAttention +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionConformerAttention(FusionAttention): + """ + Fuse Conformer Attention subgraph into one MultiHeadAttention node. + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. 
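+        # Descriptive summary of the pattern matching below (comments only, no behavior change):
+        # the fusion walks parent paths upward and must match all of the following before fusing:
+        #   qkv path (parents of the normalize node): Add, MatMul, Reshape, Transpose, MatMul
+        #   v path  (from the final MatMul):          Concat(past_v), Transpose, Reshape, Add, MatMul
+        #   qk path (from the final MatMul):          Softmax, Add, MatMul
+        #   q path  (from the QK MatMul):             Div, Transpose, Reshape, Add, MatMul
+        #   k path  (from the QK MatMul):             Transpose, Concat(past_k), Transpose, Reshape, Add, MatMul
+        # past_k/past_v come from the parents of the Concat nodes and present_k/present_v from the
+        # Concat outputs; if every path matches, the subgraph is replaced by one MultiHeadAttention node.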
+ qkv_nodes = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ) + if qkv_nodes is not None: + ( + _, + _, + reshape_qkv, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes + else: + logger.debug("fuse_conformer_attention: failed to match qkv path") + return + + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 1, 0, 0, 1], + ) + + add_v = None + if v_nodes is not None: + (concat_v, _, _, add_v, matmul_v) = v_nodes + concat_parent = self.model.get_parent(concat_v, 0, None) + present_v = concat_v.output[0] + past_v = concat_parent.output[0] + else: + logger.debug("fuse_conformer_attention: failed to match v path") + return + + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + + if qk_nodes is not None: + _, add_qk, matmul_qk = qk_nodes + else: + logger.debug("fuse_conformer_attention: failed to match qk path") + return + + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Div", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], + ) + if q_nodes is not None: + _, _, reshape_q, add_q, matmul_q = q_nodes + else: + logger.debug("fuse_conformer_attention: failed to match q path") + return + + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 1, 0, 0, 1], + ) + + matmul_k = None + if k_nodes is not None: + _, concat_k, _, _, add_k, matmul_k = k_nodes + concat_parent = self.model.get_parent(concat_k, 0, None) + past_k = concat_parent.output[0] + present_k = concat_k.output[0] + else: + logger.debug("fuse_conformer_attention: failed to match k path") + return + + attention_last_node = reshape_qkv + num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q) + + if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0: + logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size") + return + + new_node = self.create_multihead_attention_node( + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + num_heads, + hidden_size, + attention_last_node.output[0], + add_qk=add_qk.input[1], + past_k=past_k, + past_v=past_v, + present_k=present_k, + present_v=present_v, + ) + + if new_node is None: + logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed") + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + + # When using multihead attention, keep MatMul nodes in original graph + if q_nodes[-1].op_type == "MatMul": + q_nodes.pop() + if k_nodes[-1].op_type == "MatMul": + k_nodes.pop() + if v_nodes[-1].op_type == "MatMul": + v_nodes.pop() + + self.nodes_to_remove.extend(k_nodes) + self.nodes_to_remove.extend(v_nodes) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. 
+ self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index 974759bb6ae4b..4f9ecf6cbb152 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -53,7 +53,9 @@ def load_pipelines(args, batch_size): max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048 # No VAE decoder in base when it outputs latent instead of image. - base_info = PipelineInfo(args.version, use_vae=False, min_image_size=min_image_size, max_image_size=max_image_size) + base_info = PipelineInfo( + args.version, use_vae=args.disable_refiner, min_image_size=min_image_size, max_image_size=max_image_size + ) # Ideally, the optimized batch size and image size for TRT engine shall align with user's preference. That is to # optimize the shape used most frequently. We can let user config it when we develop a UI plugin. @@ -74,25 +76,28 @@ def load_pipelines(args, batch_size): opt_image_width, ) - refiner_info = PipelineInfo( - args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size - ) - refiner = init_pipeline( - Img2ImgXLPipeline, - refiner_info, - engine_type, - args, - max_batch_size, - opt_batch_size, - opt_image_height, - opt_image_width, - ) + refiner = None + if not args.disable_refiner: + refiner_info = PipelineInfo( + args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size + ) + refiner = init_pipeline( + Img2ImgXLPipeline, + refiner_info, + engine_type, + args, + max_batch_size, + opt_batch_size, + opt_image_height, + opt_image_width, + ) if engine_type == EngineType.TRT: - max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) + max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory()) _, shared_device_memory = cudart.cudaMalloc(max_device_memory) base.backend.activate_engines(shared_device_memory) - refiner.backend.activate_engines(shared_device_memory) + if refiner: + refiner.backend.activate_engines(shared_device_memory) if engine_type == EngineType.ORT_CUDA: enable_vae_slicing = args.enable_vae_slicing @@ -100,7 +105,7 @@ def load_pipelines(args, batch_size): print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.") enable_vae_slicing = True if enable_vae_slicing: - refiner.backend.enable_vae_slicing() + (refiner or base).backend.enable_vae_slicing() return base, refiner @@ -109,7 +114,8 @@ def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False image_width = args.width batch_size = len(prompt) base.load_resources(image_height, image_width, batch_size) - refiner.load_resources(image_height, image_width, batch_size) + if refiner: + refiner.load_resources(image_height, image_width, batch_size) def run_base_and_refiner(warmup=False): images, time_base = base.run( @@ -121,8 +127,10 @@ def run_base_and_refiner(warmup=False): denoising_steps=args.denoising_steps, guidance=args.guidance, seed=args.seed, - return_type="latent", + return_type="latent" if refiner else "image", ) + if refiner is None: + return images, time_base # Use same seed in base and refiner. 
seed = base.get_current_seed() @@ -173,7 +181,8 @@ def run_demo(args): base, refiner = load_pipelines(args, batch_size) run_pipelines(args, base, refiner, prompt, negative_prompt) base.teardown() - refiner.teardown() + if refiner: + refiner.teardown() def run_dynamic_shape_demo(args): @@ -223,7 +232,8 @@ def run_dynamic_shape_demo(args): args.denoising_steps = steps args.seed = seed base.set_scheduler(scheduler) - refiner.set_scheduler(scheduler) + if refiner: + refiner.set_scheduler(scheduler) print( f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}" ) @@ -231,7 +241,8 @@ def run_dynamic_shape_demo(args): run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False) base.teardown() - refiner.teardown() + if refiner: + refiner.teardown() if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index ef45b786b9ea3..39ee273a3130d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -145,6 +145,10 @@ def parse_arguments(is_xl: bool, description: str): parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + parser.add_argument( + "--disable-refiner", action="store_true", help="Disable refiner and only run base for XL pipeline." + ) + group = parser.add_argument_group("Options for ORT_CUDA engine only") group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.") diff --git a/onnxruntime/python/tools/transformers/onnx_model_conformer.py b/onnxruntime/python/tools/transformers/onnx_model_conformer.py new file mode 100644 index 0000000000000..1506d85f53fd4 --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_conformer.py @@ -0,0 +1,33 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import logging +from typing import Optional + +from fusion_attention import AttentionMask +from fusion_conformer_attention import FusionConformerAttention +from fusion_options import FusionOptions +from onnx_model_bert import BertOnnxModel + +logger = logging.getLogger(__name__) + + +class ConformerOnnxModel(BertOnnxModel): + def __init__(self, model, num_heads, hidden_size): + super().__init__(model, num_heads, hidden_size) + self.attention_mask = AttentionMask(self) + self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention + self.attention_fusion.disable_multi_head_attention_bias = ( + False if options is None else options.disable_multi_head_attention_bias + ) + super().optimize(options, add_dynamic_axes) + + def fuse_attention(self): + self.attention_fusion.apply() + + def preprocess(self): + self.adjust_reshape_and_expand() diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 94a757320e598..6842a97fe0c77 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -32,6 +32,7 @@ from onnx_model_bert_keras import BertOnnxModelKeras from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_clip import ClipOnnxModel +from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel @@ -56,6 +57,7 @@ "unet": (UnetOnnxModel, "pytorch", 1), # UNet in Stable Diffusion "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), + "conformer": (ConformerOnnxModel, "pytorch", 1), } diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 84d8a9c56df89..9ab78cac3aca4 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -589,5 +589,30 @@ TEST(FunctionTest, TestInlinedLocalFunctionNotRemoved) { #endif } +TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) { + // Verify this runs + constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/transform/gh_issue_18338.onnx"); + + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); + + // Scalar shape for input_0 and output + const std::string input_names[] = {"input_0"}; + const std::string output_names[] = {"_val_3"}; + TensorShape input_shape; + MLFloat16 input_0_data{684.f}; + + OrtValue input_0; + Tensor::InitOrtValue(DataTypeImpl::GetType(), input_shape, &input_0_data, OrtMemoryInfo(), input_0); + + std::vector fetches(1); + RunOptions run_options; + ASSERT_STATUS_OK(session_object.Run(run_options, AsSpan(input_names), AsSpan({input_0}), + AsSpan(output_names), &fetches, 0)); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc index a03d0da2538d4..9dcedd1fd7681 100644 --- a/onnxruntime/test/optimizer/compute_optimizer_test.cc +++ 
b/onnxruntime/test/optimizer/compute_optimizer_test.cc @@ -847,7 +847,7 @@ Test graph includes multiple equivalent subgraphs as below. Add an Identity node because currently, we don't allow Gather generates graph output. */ TEST(ComputeOptimizerTests, GatherLayerNormalization) { - std::vector> test_config_pairs{ + std::vector> test_config_pairs{ // { // is_scalar_slice, // ln_axis_before_propagation, @@ -929,13 +929,6 @@ TEST(ComputeOptimizerTests, GatherLayerNormalization) { const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); TEST_RETURN_IF_NOT(slice_out_shape != nullptr); - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - - auto& axis_attr = attrs.at("axis"); - auto axis_value = (int)axis_attr.i(); - TEST_RETURN_IF_NOT(axis_value == ln_axis_after); - if (is_scalar_slice) { TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 2); TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && @@ -951,10 +944,15 @@ TEST(ComputeOptimizerTests, GatherLayerNormalization) { TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && slice_out_shape->dim(2).dim_value() == 256); } - } else { TEST_RETURN_IF_NOT(producer_node == nullptr); } + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == ln_axis_after); } } @@ -2841,165 +2839,110 @@ Test graph include multiple equivalent subgraphs as below. Add an Identity node because currently we don't allow Reshape generate graph output. */ -TEST(ComputeOptimizerTests, ReshapeLayerNormalization_PropagationOnOneBranch) { - const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); - auto pre_graph_checker = [](Graph& graph) -> Status { - auto op_count_pre = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); - TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); - return Status::OK(); - }; - - auto post_graph_checker = [](Graph& graph) { - auto op_count_post = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_post.size() == 3U); - TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); - - for (Node& node : graph.Nodes()) { - if (node.OpType() == "LayerNormalization") { - const auto& input_defs = node.InputDefs(); - - { - auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); - TEST_RETURN_IF_NOT(producer_node != nullptr); - TEST_RETURN_IF_NOT(producer_node->OpType() == "Reshape"); - - InlinedVector values; - constexpr bool require_constant = true; - NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); - TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, require_constant)); - TEST_RETURN_IF_NOT(values.size() == 2); - TEST_RETURN_IF_NOT(values[0] == -1); - TEST_RETURN_IF_NOT(values[1] == 1024); - } - - { - auto producer_node = graph.GetProducerNode(input_defs[1]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } - - { - auto producer_node = graph.GetProducerNode(input_defs[2]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } - } - } - return Status::OK(); +TEST(ComputeOptimizerTests, ReshapeLayerNormalization) { + std::vector> 
test_config_pairs{ + // { + // ln_axis_before_propagation, + // expected_ln_axis_after_propagation, + // expected to propagate + // } + {0, 0, false}, + {1, 1, false}, + {2, 1, true}, + {-3, -3, false}, + {-2, -2, false}, + {-1, -1, true}, }; - std::vector fist_dim_values = {-1, 128}; - for (auto first_dim_value : fist_dim_values) { - auto build_test_case = [&first_dim_value](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput({{4, 32, 1024}}); - auto* input2_arg = builder.MakeInput({{1024}}); - auto* input3_arg = builder.MakeInput({{1024}}); - auto* ln_out = builder.MakeIntermediate(); - builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) - .AddAttribute("axis", static_cast(-1)); - - auto* shape_initializer = builder.MakeInitializer({2}, {first_dim_value, 1024}); - auto* reshape_out = builder.MakeIntermediate(); - builder.AddNode("Reshape", {ln_out, shape_initializer}, {reshape_out}); + for (auto p : test_config_pairs) { + int64_t ln_axis_before = std::get<0>(p); + int64_t ln_axis_after = std::get<1>(p); + bool expected_to_propagate = std::get<2>(p); - auto* identity_out = builder.MakeOutput(); - builder.AddNode("Identity", {reshape_out}, {identity_out}); + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + auto pre_graph_checker = [](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + return Status::OK(); }; - const std::vector opsets{12, 13, 14}; - for (auto& opset_version : opsets) { - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer), - TransformerLevel::Level1, - 1, pre_graph_checker, post_graph_checker)); - } - } -} + auto post_graph_checker = [ln_axis_after, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); -/* -Test graph include multiple equivalent subgraphs as below. - graph input [4, 32, 1024] (float) graph input [1024] (float) graph input [1024] (float) - | | / - \_____________ _______/ __________________________/ - \ / / - LayerNormalization - | - Reshape - | - Identity - | - graph out [128, 1024] (float) + for (Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + const auto& input_defs = node.InputDefs(); -Add an Identity node because currently we don't allow Reshape generate graph output. 
-*/ -TEST(ComputeOptimizerTests, ReshapeLayerNormalization_NoPropagation) { - const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); - auto pre_graph_checker = [](Graph& graph) -> Status { - auto op_count_pre = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); - TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); - return Status::OK(); - }; + if (expected_to_propagate) { + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Reshape"); - auto post_graph_checker = [](Graph& graph) { - auto op_count_post = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_count_post.size() == 3U); - TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Reshape"] == 1); - TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, require_constant)); + TEST_RETURN_IF_NOT(values.size() == 2); + TEST_RETURN_IF_NOT(values[0] == -1); + TEST_RETURN_IF_NOT(values[1] == 1024); + } else { + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + TEST_RETURN_IF_NOT(producer_node == nullptr); + } - for (Node& node : graph.Nodes()) { - if (node.OpType() == "LayerNormalization") { - const auto& input_defs = node.InputDefs(); + { + auto producer_node = graph.GetProducerNode(input_defs[1]->Name()); + TEST_RETURN_IF_NOT(producer_node == nullptr); + } - { - auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } + { + auto producer_node = graph.GetProducerNode(input_defs[2]->Name()); + TEST_RETURN_IF_NOT(producer_node == nullptr); + } - { - auto producer_node = graph.GetProducerNode(input_defs[1]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); - } + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - { - auto producer_node = graph.GetProducerNode(input_defs[2]->Name()); - TEST_RETURN_IF_NOT(producer_node == nullptr); + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == ln_axis_after); } } - } - return Status::OK(); - }; - - std::vector fist_dim_values = {-1, 128}; - for (auto first_dim_value : fist_dim_values) { - auto build_test_case = [&first_dim_value](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput({{4, 32, 1024}}); - auto* input2_arg = builder.MakeInput({{1024}}); - auto* input3_arg = builder.MakeInput({{1024}}); - auto* ln_out = builder.MakeIntermediate(); - builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) - .AddAttribute("axis", static_cast(1)); - - auto* shape_initializer = builder.MakeInitializer({2}, {first_dim_value, 1024}); - auto* reshape_out = builder.MakeIntermediate(); - builder.AddNode("Reshape", {ln_out, shape_initializer}, {reshape_out}); - - auto* identity_out = builder.MakeOutput(); - builder.AddNode("Identity", {reshape_out}, {identity_out}); + return Status::OK(); }; - const std::vector opsets{12, 13, 14}; - for (auto& opset_version : opsets) { - std::unique_ptr transformer = std::make_unique(); - 
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer), - TransformerLevel::Level1, - 1, pre_graph_checker, post_graph_checker)); + std::vector fist_dim_values = {-1, 128}; + for (auto first_dim_value : fist_dim_values) { + auto build_test_case = [ln_axis_before, &first_dim_value](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{4, 32, 1024}}); + auto* input2_arg = builder.MakeInput({{1024}}); + auto* input3_arg = builder.MakeInput({{1024}}); + auto* ln_out = builder.MakeIntermediate(); + builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) + .AddAttribute("axis", ln_axis_before); + + auto* shape_initializer = builder.MakeInitializer({2}, {first_dim_value, 1024}); + auto* reshape_out = builder.MakeIntermediate(); + builder.AddNode("Reshape", {ln_out, shape_initializer}, {reshape_out}); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {reshape_out}, {identity_out}); + }; + + const std::vector opsets{12, 13, 14}; + for (auto& opset_version : opsets) { + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } } } } diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 3073dde9d8e4c..3da3dc858175b 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -142,11 +142,6 @@ TEST_F(QnnHTPBackendTests, MatMulOp_HTP_u8) { } // Test QDQ MatMul with 16-bit act, 8-bit weights (static) -// TODO: (SLIGHT) Inaccuracy detected for output 'output', element 0. -// Output quant params: scale=0.0015259021893143654, zero_point=0. -// Expected val: 98 -// QNN QDQ val: 97.720298767089844 (err 0.27970123291015625) -// CPU QDQ val: 97.726402282714844 (err 0.27359771728515625) TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; @@ -158,6 +153,40 @@ TEST_F(QnnHTPBackendTests, MatMulOp_HTP_A16_W8Static) { 7e-3f); } +// Test QDQ MatMul with uint16 activation uint16 weights, both dynamic +// Inaccuracy detected for output 'output_0', element 1. +// Output quant params: scale=0.0015259021893143654, zero_point=0. +// Expected val: 40 +// QNN QDQ val: 39.681087493896484 (err 0.31891250610351562) +// CPU QDQ val: 39.99847412109375 (err 0.00152587890625) +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16Dynamic) { + std::vector input0_data = {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}; + std::vector input1_data = {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}; + RunQDQMatMulOpOpTest(TestInputDef({2, 3}, false, input0_data), + TestInputDef({3, 2}, false, input1_data), + ExpectedEPNodeAssignment::All, + 18, + true, // Use com.microsoft Q/DQ ops + 7e-3f); +} + +// Test QDQ MatMul with uint16 activation uint16 weights, both dynamic +// Inaccuracy detected for output 'output_0', element 1. +// Output quant params: scale=0.71908456087112427, zero_point=1. 
+// Expected val: 46848.41015625 +// QNN QDQ val: 46844.04296875 (err 4.3671875) +// CPU QDQ val: 46848.359375 (err 0.05078125) +TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_HTP_A16_W16DynamicLarge) { + std::vector input0_data = GetFloatDataInRange(-10.0f, 10.0f, 12 * 96 * 512); + std::vector input1_data = GetFloatDataInRange(-10.0f, 10.0f, 12 * 96 * 512); + RunQDQMatMulOpOpTest(TestInputDef({1, 12, 96, 512}, false, input0_data), + TestInputDef({1, 12, 512, 96}, false, input1_data), + ExpectedEPNodeAssignment::All, + 18, + true, // Use com.microsoft Q/DQ ops + 7e-3f); +} + // Test 16-bit QDQ MatMul with static weights // TODO: Inaccuracy detected for output 'output', element 0. // Output quant params: scale=0.0015259021893143654, zero_point=0. diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py deleted file mode 100644 index 4cf2e5d7f7588..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ /dev/null @@ -1,1026 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import copy -import os -import unittest - -import numpy as np -import onnx -import torch -import torch.nn as nn -import torch.nn.functional as F -from helper import get_name -from numpy.testing import assert_allclose -from torchvision import datasets, transforms - -import onnxruntime -from onnxruntime.capi.ort_trainer import ( - IODescription, - LossScaler, - ModelDescription, - ORTTrainer, - generate_sample, - load_checkpoint, - save_checkpoint, -) - -SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) - - -def ort_trainer_learning_rate_description(): - return IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - - -def remove_extra_info(model_desc): - simple_model_desc = copy.deepcopy(model_desc) - for input_desc in simple_model_desc.inputs_: - input_desc.dtype_ = None - input_desc.num_classes_ = None - for output_desc in simple_model_desc.outputs_: - output_desc.dtype_ = None - output_desc.num_classes_ = None - return simple_model_desc - - -def bert_model_description(): - vocab_size = 30528 - input_ids_desc = IODescription( - "input_ids", - ["batch", "max_seq_len_in_batch"], - torch.int64, - num_classes=vocab_size, - ) - segment_ids_desc = IODescription("segment_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2) - input_mask_desc = IODescription("input_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2) - masked_lm_labels_desc = IODescription( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - torch.int64, - num_classes=vocab_size, - ) - next_sentence_labels_desc = IODescription( - "next_sentence_labels", - [ - "batch", - ], - torch.int64, - num_classes=2, - ) - loss_desc = IODescription("loss", [], torch.float32) - - return ModelDescription( - [ - input_ids_desc, - segment_ids_desc, - input_mask_desc, - masked_lm_labels_desc, - next_sentence_labels_desc, - ], - [loss_desc], - ) - - -def map_optimizer_attributes(name): - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) - if no_decay: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - else: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6} - - -def generate_sample_batch(desc, batch_size, device): - desc_ = copy.deepcopy(desc) - desc_.shape_[0] = batch_size - sample = generate_sample(desc_, device) - return sample - - -def 
create_ort_trainer( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc=True, - loss_scaler=None, - deepspeed_zero_stage=0, -): - model_desc = bert_model_description() - simple_model_desc = remove_extra_info(model_desc) if use_simple_model_desc else model_desc - learning_rate_description = ort_trainer_learning_rate_description() - device = torch.device("cuda", 0) - - onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx")) - - model = ORTTrainer( - onnx_model, - None, - simple_model_desc, - "LambOptimizer", - map_optimizer_attributes, - learning_rate_description, - device, - gradient_accumulation_steps=gradient_accumulation_steps, - world_rank=0, - world_size=1, - loss_scaler=loss_scaler, - use_mixed_precision=use_mixed_precision, - allreduce_post_accumulation=allreduce_post_accumulation, - deepspeed_zero_stage=deepspeed_zero_stage, - ) - - return model, model_desc, device - - -def run_bert_training_test( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc=True, - use_internel_loss_scale=False, -): - torch.manual_seed(1) - onnxruntime.set_seed(1) - - loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None - - model, model_desc, device = create_ort_trainer( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc, - loss_scaler, - ) - - if loss_scaler is None: - loss_scaler = LossScaler(model.loss_scale_input_name, True) - - input_ids_batches = [] - segment_ids_batches = [] - input_mask_batches = [] - masked_lm_labels_batches = [] - next_sentence_labels_batches = [] - batch_size = 16 - num_batches = 8 - for _batch in range(num_batches): - input_ids_batches = [ - *input_ids_batches, - generate_sample_batch(model_desc.inputs_[0], batch_size, device), - ] - segment_ids_batches = [ - *segment_ids_batches, - generate_sample_batch(model_desc.inputs_[1], batch_size, device), - ] - input_mask_batches = [ - *input_mask_batches, - generate_sample_batch(model_desc.inputs_[2], batch_size, device), - ] - masked_lm_labels_batches = [ - *masked_lm_labels_batches, - generate_sample_batch(model_desc.inputs_[3], batch_size, device), - ] - next_sentence_labels_batches = [ - *next_sentence_labels_batches, - generate_sample_batch(model_desc.inputs_[4], batch_size, device), - ] - - lr_batch_list = [ - 0.0000000e00, - 4.6012269e-07, - 9.2024538e-07, - 1.3803681e-06, - 1.8404908e-06, - 2.3006135e-06, - 2.7607362e-06, - 3.2208588e-06, - 3.6809815e-06, - ] - - actual_losses = [] - actual_all_finites = [] - - for batch_count in range(num_batches): - input_ids = generate_sample_batch(model_desc.inputs_[0], batch_size, device) - segment_ids = generate_sample_batch(model_desc.inputs_[1], batch_size, device) - input_mask = generate_sample_batch(model_desc.inputs_[2], batch_size, device) - masked_lm_labels = generate_sample_batch(model_desc.inputs_[3], batch_size, device) - next_sentence_labels = generate_sample_batch(model_desc.inputs_[4], batch_size, device) - lr = lr_batch_list[batch_count] - - learning_rate = torch.tensor([lr]).to(device) - training_args = [ - input_ids, - segment_ids, - input_mask, - masked_lm_labels, - next_sentence_labels, - learning_rate, - ] - if use_mixed_precision: - if not use_internel_loss_scale: - loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device) - training_args.append(loss_scale) - actual_loss = model.train_step(*training_args) - if isinstance(actual_loss, (list, 
tuple)): - assert len(actual_loss) == 2 - actual_loss, actual_all_finite = actual_loss - if not use_internel_loss_scale: - loss_scaler.update_loss_scale(actual_all_finite.item()) - actual_all_finites = [ - *actual_all_finites, - actual_all_finite.cpu().numpy().item(0), - ] - - actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)] - else: - loss = model(*training_args) - actual_losses = [*actual_losses, loss.cpu().numpy().item(0)] - - if batch_count == num_batches - 1: - # test eval_step api with fetches at the end of the training. - # if eval_step is called during the training, it will affect the actual training loss (training session is stateful). - eval_loss = model.eval_step( - input_ids, - segment_ids, - input_mask, - masked_lm_labels, - next_sentence_labels, - fetches=["loss"], - ) - eval_loss = eval_loss.cpu().numpy().item(0) - - # If using internal loss scale, all_finites are handled internally too. - if use_mixed_precision and not use_internel_loss_scale: - return actual_losses, actual_all_finites, eval_loss - else: - return actual_losses, eval_loss - - -class MNISTWrapper: - class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - self.register_buffer("bias_buffer", torch.tensor(1e-6)) - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - out = torch.add(out, self.bias_buffer.to(out.dtype)) - return out - - class NeuralNetWithLoss(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x, target): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return F.nll_loss(F.log_softmax(out, dim=1), target), out - - def my_loss(x, target): # noqa: N805 - return F.nll_loss(F.log_softmax(x, dim=1), target) - - def train_with_trainer(self, learningRate, trainer, device, train_loader, epoch): - actual_losses = [] - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - args_log_interval = 100 - if batch_idx % args_log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - actual_losses = [*actual_losses, loss.cpu().numpy().item()] - - return actual_losses - - # TODO: comple this once ORT training can do evaluation. 
- def test_with_trainer(self, trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = F.log_softmax(trainer.eval_step((data), fetches=["probability"]), dim=1) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, - correct, - len(test_loader.dataset), - 100.0 * correct / len(test_loader.dataset), - ) - ) - - return test_loss, correct / len(test_loader.dataset) - - def mnist_model_description(): - input_desc = IODescription("input1", ["batch", 784], torch.float32) - label_desc = IODescription( - "label", - [ - "batch", - ], - torch.int64, - num_classes=10, - ) - loss_desc = IODescription("loss", [], torch.float32) - probability_desc = IODescription("probability", ["batch", 10], torch.float32) - return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc]) - - def get_loaders(self): - args_batch_size = 64 - args_test_batch_size = 1000 - - kwargs = {"num_workers": 0, "pin_memory": True} - # set shuffle to False to get deterministic data set among different torch version - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - os.path.join(SCRIPT_DIR, "data"), - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args_batch_size, - shuffle=False, - **kwargs, - ) - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - os.path.join(SCRIPT_DIR, "data"), - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args_test_batch_size, - shuffle=False, - **kwargs, - ) - - return train_loader, test_loader - - def get_model(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - - # warning: changes the pytorch random generator state - model = MNISTWrapper.NeuralNet(input_size, hidden_size, num_classes) - model_desc = MNISTWrapper.mnist_model_description() - return model, model_desc - - def get_model_with_internal_loss(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - - # warning: changes the pytorch random generator state - model = MNISTWrapper.NeuralNetWithLoss(input_size, hidden_size, num_classes) - model_desc = MNISTWrapper.mnist_model_description() - return model, model_desc - - def get_trainer( - self, - model, - model_desc, - device, - onnx_opset_ver=12, - frozen_weights=[], # noqa: B006 - internal_loss_fn=False, - get_lr_this_step=None, - optimizer="SGDOptimizer", - ): - loss_fn = MNISTWrapper.my_loss if not internal_loss_fn else None - return ORTTrainer( - model, - loss_fn, - model_desc, - optimizer, - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - _opset_version=onnx_opset_ver, - frozen_weights=frozen_weights, - get_lr_this_step=get_lr_this_step, - ) - - -class TestOrtTrainer(unittest.TestCase): - def run_mnist_training_and_testing(onnx_opset_ver): # noqa: N805 - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = 
mnist.get_model() - trainer = mnist.get_trainer(model, model_desc, device, onnx_opset_ver=onnx_opset_ver) - - learningRate = 0.01 # noqa: N806 - args_epochs = 2 - expected_losses = [ - 2.312044143676758, - 0.8018650412559509, - 0.5819257497787476, - 0.47025489807128906, - 0.35800155997276306, - 0.41124576330184937, - 0.2731882333755493, - 0.4201386570930481, - 0.39458805322647095, - 0.38380366563796997, - 0.2722422480583191, - 0.24230478703975677, - 0.23505745828151703, - 0.33442264795303345, - 0.21140924096107483, - 0.31545233726501465, - 0.18556523323059082, - 0.3453553020954132, - 0.29598352313041687, - 0.3595045208930969, - ] - - expected_test_losses = [0.3145490005493164, 0.256188737487793] - expected_test_accuracies = [0.9075, 0.9265] - - actual_losses = [] - actual_test_losses, actual_accuracies = [], [] - for epoch in range(1, args_epochs + 1): - actual_losses = [ - *actual_losses, - *mnist.train_with_trainer(learningRate, trainer, device, train_loader, epoch), - ] - - test_loss, accuracy = mnist.test_with_trainer(trainer, device, test_loader) - actual_test_losses = [*actual_test_losses, test_loss] - actual_accuracies = [*actual_accuracies, accuracy] - - # if you update outcomes, also do so for resume from checkpoint test - # args_checkpoint_epoch = 1 - # if epoch == args_checkpoint_epoch: - # state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()} - # torch.save(state, get_name("ckpt_mnist.pt")) - - print("actual_losses=", actual_losses) - print("actual_test_losses=", actual_test_losses) - print("actual_accuracies=", actual_accuracies) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # import pdb; pdb.set_trace() - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_test_losses, - actual_test_losses, - rtol=rtol, - err_msg="test loss mismatch", - ) - assert_allclose( - expected_test_accuracies, - actual_accuracies, - rtol=rtol, - err_msg="test accuracy mismatch", - ) - - def test_mnist_training_and_testing_opset12(self): - TestOrtTrainer.run_mnist_training_and_testing(onnx_opset_ver=12) - - def test_mnist_resume_training_and_testing(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - learningRate = 0.01 # noqa: N806 - args_epochs = 2 - args_checkpoint_epoch = 1 - # should match those in test without checkpointing - expected_losses = [ - 0.26509523391723633, - 0.24135658144950867, - 0.2397943139076233, - 0.3351520597934723, - 0.20998981595039368, - 0.31488314270973206, - 0.18481917679309845, - 0.34727591276168823, - 0.2971782684326172, - 0.3609251379966736, - ] - - expected_test_losses = [0.25632242965698243] - expected_test_accuracies = [0.9264] - - actual_losses = [] - actual_test_losses, actual_accuracies = [], [] - - # restore from checkpoint - resume_trainer = mnist.get_trainer(model, model_desc, device) - checkpoint = torch.load(get_name("ckpt_mnist.pt"), map_location="cpu") - torch.set_rng_state(checkpoint["rng_state"]) - resume_trainer.load_state_dict(checkpoint["model"], strict=True) - - # continue .. 
- for epoch in range(args_checkpoint_epoch + 1, args_epochs + 1): - actual_losses = [ - *actual_losses, - *mnist.train_with_trainer(learningRate, resume_trainer, device, train_loader, epoch), - ] - - test_loss, accuracy = mnist.test_with_trainer(resume_trainer, device, test_loader) - actual_test_losses = [*actual_test_losses, test_loss] - actual_accuracies = [*actual_accuracies, accuracy] - - print("actual_losses=", actual_losses) - print("actual_test_losses=", actual_test_losses) - print("actual_accuracies=", actual_accuracies) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # import pdb; pdb.set_trace() - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_test_losses, - actual_test_losses, - rtol=rtol, - err_msg="test loss mismatch", - ) - assert_allclose( - expected_test_accuracies, - actual_accuracies, - rtol=rtol, - err_msg="test accuracy mismatch", - ) - - def test_mnist_state_dict(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - state_dict = trainer.state_dict() - assert state_dict == {} - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - state_dict = trainer.state_dict() - assert state_dict.keys() == { - "fc1.bias", - "fc1.weight", - "fc2.bias", - "fc2.weight", - "bias_buffer", - } - - def test_mnist_save_as_onnx(self): - torch.manual_seed(1) - device = torch.device("cuda") - onnx_file_name = "mnist.onnx" - if os.path.exists(onnx_file_name): - os.remove(onnx_file_name) - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - trainer.save_as_onnx(onnx_file_name) - assert not os.path.exists(onnx_file_name) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - trainer.save_as_onnx(onnx_file_name) - assert os.path.exists(onnx_file_name) - - def test_mnist_device(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - for model_device in [torch.device("cpu"), torch.device("cuda")]: - model.to(model_device) - trainer = mnist.get_trainer(model, model_desc, device) - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - def test_mnist_initializer_names(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = 
data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - assert ({n.name for n in trainer.onnx_model_.graph.initializer} - {"bias_buffer"}) == { - n for n, t in model.named_parameters() - } - - def test_mnist_initializer_names_with_internal_loss(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model_with_internal_loss() - - def get_lr_this_step(global_step): - learningRate = 0.02 # noqa: N806 - return torch.tensor([learningRate]) - - trainer = mnist.get_trainer( - model, - model_desc, - device, - internal_loss_fn=True, - get_lr_this_step=get_lr_this_step, - ) - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target) - - assert {n.name for n in trainer.onnx_model_.graph.initializer} == {n for n, t in model.named_parameters()} - - def test_mnist_frozen_weight(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=["fc1.weight"]) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_1 = trainer.state_dict()["fc1.weight"] - fc2_trainstep_1 = trainer.state_dict()["fc2.weight"] - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_2 = trainer.state_dict()["fc1.weight"] - fc2_trainstep_2 = trainer.state_dict()["fc2.weight"] - assert np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and not np.array_equal(fc2_trainstep_1, fc2_trainstep_2) - - def test_mnist_torch_buffer(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_1 = trainer.state_dict()["fc1.weight"] - bias_buffer_trainstep_1 = trainer.state_dict()["bias_buffer"] - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_2 = trainer.state_dict()["fc1.weight"] - bias_buffer_trainstep_2 = trainer.state_dict()["bias_buffer"] - assert not np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and np.array_equal( - bias_buffer_trainstep_1, bias_buffer_trainstep_2 - ) - - def test_mnist_frozen_weight_checkpoint(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=["fc1.weight"]) - - learningRate = 0.02 # noqa: N806 - - # do one train step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, 
torch.tensor([learningRate])) - - # do one eval step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.eval_step(data, target) - - # save checkpoint, load model and compare - state_dict = trainer.state_dict() - - new_model, _ = mnist.get_model() - trainer = mnist.get_trainer(new_model, model_desc, device, frozen_weights=["fc1.weight"]) - trainer.load_state_dict(state_dict) - - ckpt_loss, _ = trainer.eval_step(data, target) - assert loss == ckpt_loss - - loaded_state_dict = trainer.state_dict() - assert state_dict.keys() == loaded_state_dict.keys() - - def test_mnist_training_checkpoint(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer( - model, - model_desc, - device, - optimizer="LambOptimizer", - frozen_weights=["fc1.weight"], - ) - - learningRate = 0.02 # noqa: N806 - - # do 5 train step - for _i in range(5): - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - # do one eval step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.eval_step(data, target) - - # save checkpoint, load model and compare - state_dict = trainer.state_dict() - - new_model, _ = mnist.get_model() - trainer = mnist.get_trainer( - new_model, - model_desc, - device, - optimizer="LambOptimizer", - frozen_weights=["fc1.weight"], - ) - trainer.load_state_dict(state_dict) - - ckpt_loss, _ = trainer.eval_step(data, target) - assert loss == ckpt_loss - - loaded_state_dict = trainer.state_dict() - assert state_dict.keys() == loaded_state_dict.keys() - for key in state_dict: - assert np.array_equal(state_dict[key], loaded_state_dict[key]) - - def test_bert_training_basic(self): - expected_losses = [ - 11.027887, - 11.108191, - 11.055356, - 11.040912, - 10.960277, - 11.02691, - 11.082471, - 10.920979, - ] - expected_eval_loss = [10.958977] - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=False, - ) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # print('losses expected: ', expected_losses) - # print('losses actual: ', actual_losses) - # print('eval_loss expected: ', expected_eval_loss) - # print('eval_loss actual: ', actual_eval_loss) - # import pdb; pdb.set_trace() - - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_gradient_accumulation(self): - expected_losses = [ - 11.027887, - 11.108191, - 11.055354, - 11.040904, - 10.960266, - 11.026897, - 11.082475, - 10.920998, - ] - expected_eval_loss = [10.958998] - - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=4, - use_mixed_precision=False, - allreduce_post_accumulation=False, - ) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # print('losses expected: ', expected_losses) - # print('losses actual: ', actual_losses) - # 
print('eval_loss expected: ', expected_eval_loss) - # print('eval_loss actual: ', actual_eval_loss) - # import pdb; pdb.set_trace() - - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_checkpointing_basic(self): - model, _, _ = create_ort_trainer( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=True, - use_simple_model_desc=True, - loss_scaler=None, - ) - sd = model.state_dict() - - # modify one of the default values - sd["bert.encoder.layer.0.attention.output.LayerNorm.weight"] += 1 - model.load_state_dict(sd) - - ckpt_dir = "testdata" - save_checkpoint(model, ckpt_dir, "bert_toy_save_test") - del model - - # create new model - model2, _, _ = create_ort_trainer( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=True, - use_simple_model_desc=True, - loss_scaler=None, - ) - - # load changed checkpoint - load_checkpoint(model2, ckpt_dir, "bert_toy_save_test") - loaded_sd = model2.state_dict() - - for k, v in loaded_sd.items(): - assert torch.all(torch.eq(v, sd[k])) - - def test_wrap_model_loss_fn_state_dict(self): - torch.manual_seed(1) - device = torch.device("cuda") - - class LinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - - def forward(self, y=None, x=None): - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - pt_model = LinearModel() - data = torch.randn(2, 2) - label = torch.tensor([0, 1], dtype=torch.int64) - input_desc = IODescription("x", [2, 2], torch.float32) - label_desc = IODescription( - "label", - [ - 2, - ], - torch.int64, - num_classes=4, - ) - output_desc = IODescription("output", [2, 4], torch.float32) - loss_desc = IODescription("loss", [], torch.float32) - model_desc = ModelDescription([input_desc, label_desc], [loss_desc, output_desc]) - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - def get_lr_this_step(global_step): - learningRate = 0.02 # noqa: N806 - return torch.tensor([learningRate]) - - ort_trainer = ORTTrainer( - pt_model, - loss_fn, - model_desc, - "SGDOptimizer", - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - get_lr_this_step=get_lr_this_step, - ) - ort_trainer.train_step(x=data, label=label) - state_dict = ort_trainer.state_dict() - assert state_dict.keys() == {"linear.bias", "linear.weight"} - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py deleted file mode 100644 index 3b994e6f26710..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -import unittest - -from numpy.testing import assert_allclose, assert_array_equal -from onnxruntime_test_ort_trainer import run_bert_training_test - - -class TestOrtTrainer(unittest.TestCase): - def test_bert_training_mixed_precision(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006105422973633, - 11.047048568725586, - 11.027417182922363, - 11.015759468078613, - 11.060905456542969, - 10.971782684326172, - ] - expected_all_finites = [True, True, True, True, True, True, True, True] - expected_eval_loss = [10.959012985229492] - actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_mixed_precision_internal_loss_scale(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006105422973633, - 11.047048568725586, - 11.027417182922363, - 11.015759468078613, - 11.060905456542969, - 10.971782684326172, - ] - expected_eval_loss = [10.959012985229492] - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - use_internel_loss_scale=True, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_gradient_accumulation_mixed_precision(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006077766418457, - 11.047025680541992, - 11.027434349060059, - 11.0156831741333, - 11.060973167419434, - 10.971841812133789, - ] - expected_all_finites = [True, True] - expected_eval_loss = [10.95903205871582] - actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=4, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py b/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py deleted file mode 100644 index 540f39b797bdb..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -import unittest - -import torch -import torch.nn as nn -from numpy.testing import assert_allclose -from onnxruntime_test_ort_trainer import map_optimizer_attributes, ort_trainer_learning_rate_description -from onnxruntime_test_training_unittest_utils import process_dropout - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer - - -class TestTrainingDropout(unittest.TestCase): - def setUp(self): - torch.manual_seed(1) - onnxruntime.set_seed(1) - - @unittest.skip( - "Temporarily disable this test. The graph below will trigger ORT to " - "sort backward graph before forward graph which gives incorrect result. " - "https://github.com/microsoft/onnxruntime/issues/16801" - ) - def test_training_and_eval_dropout(self): - class TwoDropoutNet(nn.Module): - def __init__(self, drop_prb_1, drop_prb_2, dim_size): - super().__init__() - self.drop_1 = nn.Dropout(drop_prb_1) - self.drop_2 = nn.Dropout(drop_prb_2) - self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32)) - - def forward(self, x): - x = x + self.weight_1 - x = self.drop_1(x) - x = self.drop_2(x) - output = x - return output[0] - - dim_size = 3 - device = torch.device("cuda", 0) - # This will drop all values, therefore expecting all 0 in output tensor - model = TwoDropoutNet(0.999, 0.999, dim_size) - input_desc = IODescription("input", [dim_size], torch.float32) - output_desc = IODescription("output", [], torch.float32) - model_desc = ModelDescription([input_desc], [output_desc]) - lr_desc = ort_trainer_learning_rate_description() - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes, - lr_desc, - device, - postprocess_model=process_dropout, - world_rank=0, - world_size=1, - ) - input = torch.ones(dim_size, dtype=torch.float32).to(device) - expected_training_output = [0.0] - expected_eval_output = [1.0] - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [input, learning_rate] - train_output = model.train_step(*input_args) - - rtol = 1e-04 - assert_allclose( - expected_training_output, - train_output.item(), - rtol=rtol, - err_msg="dropout training loss mismatch", - ) - - eval_output = model.eval_step(input) - assert_allclose( - expected_eval_output, - eval_output.item(), - rtol=rtol, - err_msg="dropout eval loss mismatch", - ) - - # Do another train step to make sure it's using original ratios - train_output_2 = model.train_step(*input_args) - assert_allclose( - expected_training_output, - train_output_2.item(), - rtol=rtol, - err_msg="dropout training loss 2 mismatch", - ) - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py b/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py deleted file mode 100644 index 3d3feca06a99b..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np -from onnx import numpy_helper - - -def get_node_index(model, node): - i = 0 - while i < len(model.graph.node): - if model.graph.node[i] == node: - break - i += 1 - return i if i < len(model.graph.node) else None - - -def add_const(model, name, output, t_value=None, f_value=None): - const_node = model.graph.node.add() - const_node.op_type = "Constant" - const_node.name = name - const_node.output.extend([output]) - attr = const_node.attribute.add() - attr.name = "value" - if t_value is not None: - attr.type = 4 - 
attr.t.CopyFrom(t_value) - else: - attr.type = 1 - attr.f = f_value - return const_node - - -def process_dropout(model): - dropouts = [] - index = 0 - for node in model.graph.node: - if node.op_type == "Dropout": - new_dropout = model.graph.node.add() - new_dropout.op_type = "TrainableDropout" - new_dropout.name = "TrainableDropout_%d" % index - # make ratio node - ratio = np.asarray([node.attribute[0].f], dtype=np.float32) - print(ratio.shape) - ratio_value = numpy_helper.from_array(ratio) - ratio_node = add_const( - model, - "dropout_node_ratio_%d" % index, - "dropout_node_ratio_%d" % index, - t_value=ratio_value, - ) - print(ratio_node) - new_dropout.input.extend([node.input[0], ratio_node.output[0]]) - new_dropout.output.extend(node.output) - dropouts.append(get_node_index(model, node)) - index += 1 - dropouts.sort(reverse=True) - for d in dropouts: - del model.graph.node[d] - model.opset_import[0].version = 10 diff --git a/onnxruntime/test/python/transformers/conformer_model_generator.py b/onnxruntime/test/python/transformers/conformer_model_generator.py new file mode 100644 index 0000000000000..71e4f2b63cf4f --- /dev/null +++ b/onnxruntime/test/python/transformers/conformer_model_generator.py @@ -0,0 +1,543 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from typing import List + +import numpy as np +import onnx +from bert_model_generator import float_tensor +from onnx import TensorProto, helper, numpy_helper + + +# Adapted from bert_model_generator.py +def get_tensor_and_weight(name: str, shape: List[int], random=False, zeros=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = ( + [np.random.uniform(low, high) for _ in range(total_elements)] + if random + else [0.0] * total_elements + if zeros + else [1.0] * total_elements + ) + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights), weights + + +def create_conformer_attention( + hidden_size=512, + num_heads=8, + epsilon=0.000009999999747378752, + add_before_layernorm=False, + fused=False, +): + # Get head size and ensure head size is an integer + assert hidden_size % num_heads == 0 + head_size = hidden_size // num_heads + + # Construct input and output nodes + inputs = [ + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("inp_cache_k", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]), + helper.make_tensor_value_info("inp_cache_v", TensorProto.FLOAT, [24, "batch_size", 8, 72, head_size]), + ] + outputs = [ + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["batch_size", 8, hidden_size]), + helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["batch_size", 8, 512]), + helper.make_tensor_value_info("oup_cache_k", TensorProto.FLOAT, ["batch_size", 8, 80, 64]), + helper.make_tensor_value_info("oup_cache_v", TensorProto.FLOAT, ["batch_size", 8, 80, 64]), + ] + nodes = [] + + # Create layernorm (Add + LayerNorm or SkipLayerNorm) + if add_before_layernorm: + nodes.extend( + [ + helper.make_node( + "Add", ["input_0", "input_1"], ["layernorm_output_to_skiplayernorm"], "add_before_layernorm" + ), + helper.make_node( + 
"LayerNormalization", + ["layernorm_output_to_skiplayernorm", "layernorm_weight", "layernorm_bias"], + ["layernorm_add_output_to_matmul"], + "layernorm", + epsilon=epsilon, + ), + ] + ) + else: + nodes.append( + helper.make_node( + "SkipLayerNormalization", + ["input_0", "input_1", "layernorm_weight", "layernorm_bias"], + ["layernorm_add_output_to_matmul", "", "", "layernorm_add_output_to_skiplayernorm"], + "skiplayernorm", + domain="com.microsoft", + epsilon=epsilon, + ) + ) + + if fused: + fused_q_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "q_weight"], + ["q_matmul_output"], + "q_path_matmul", + ), + helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), + helper.make_node( + "Reshape", ["q_add_output", "k_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d", allowzero=0 + ), + helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Div", + ["q_4d_bnsh", "q_scale"], + ["q_div_output"], + "q_div_by_sqrt_head_size", + ), + ] + nodes.extend(fused_q_nodes) + nodes.extend( + [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "k_weight"], + ["k_matmul_output"], + "k_path_matmul", + ), + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "v_weight"], + ["v_matmul_output"], + "v_path_matmul", + ), + helper.make_node( + "Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb", allowzero=0 + ), + helper.make_node( + "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "MatMul", + ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"], + ["pos_matmul"], + "pos_embed_matmul", + ), + helper.make_node( + "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "Reshape", + ["transpose_pos_matmul", "position_embed_output"], + ["reshape_position_emb"], + "final_reshape_pos_emb", + allowzero=0, + ), + helper.make_node( + "MultiHeadAttention", + [ + "q_matmul_output", + "k_matmul_output", + "v_matmul_output", + "Attention_0_qkv_bias", + "", + "reshape_position_emb", + "gather_past_k_output", + "gather_past_v_output", + ], + ["attn_output", "oup_cache_k", "oup_cache_v"], + "Attention_0", + domain="com.microsoft", + num_heads=num_heads, + ), + ] + ) + # Create nodes used with qkv concats, reshapes, and transposes + nodes.extend( + [ + helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0), + helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), + helper.make_node( + "Mul", + ["gather_0_output", "num_heads_int"], + ["mul_attn_heads_output"], + "mul_num_heads", + ), + helper.make_node( + "Unsqueeze", + ["mul_attn_heads_output", "unsqueeze_axes_input"], + ["unsqueeze_position_embed"], + "unsqueeze_position_embed", + ), + helper.make_node( + "Concat", + ["unsqueeze_position_embed", "neg_one", "head_size"], + ["position_embed_output"], + "position_embed_concat_output", + axis=0, + ), + helper.make_node( + "Unsqueeze", + ["gather_0_output", "unsqueeze_axes_input"], + ["unsqueeze_attn_heads_output"], + "unsqueeze_num_heads", + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["k_attn_heads_output"], + "k_num_heads", + axis=0, + ), + ] + ) + + nodes.extend( + [ + helper.make_node("Gather", ["inp_cache_v", "idx_0"], 
["gather_past_v_output"], "gather_past_v", axis=0), + helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0), + ] + ) + else: + # Create nodes for Q/K/V paths + q_nodes = [ + helper.make_node( + "MatMul", ["layernorm_add_output_to_matmul", "q_weight"], ["q_matmul_output"], "q_path_matmul" + ), + helper.make_node("Add", ["q_bias", "q_matmul_output"], ["q_add_output"], "q_path_add"), + helper.make_node("Reshape", ["q_add_output", "q_attn_heads_output"], ["q_4d_bsnh"], "q_reshape_to_4d"), + helper.make_node("Transpose", ["q_4d_bsnh"], ["q_4d_bnsh"], "q_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Div", + ["q_4d_bnsh", "q_scale"], + ["q_div_output"], + "q_div_by_sqrt_head_size", + ), + ] + k_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "k_weight"], + ["k_matmul_output"], + "k_path_matmul", + ), + helper.make_node("Add", ["k_bias", "k_matmul_output"], ["k_add_output"], "k_path_add"), + helper.make_node("Reshape", ["k_add_output", "k_attn_heads_output"], ["k_4d_bsnh"], "k_reshape_to_4d"), + helper.make_node("Transpose", ["k_4d_bsnh"], ["k_4d_bnsh"], "k_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Concat", + ["gather_past_k_output", "k_4d_bnsh"], + ["oup_cache_k"], + "concat_past_k_and_curr_k", + axis=2, + ), + helper.make_node( + "Transpose", + ["oup_cache_k"], + ["k_output_transpose"], + "k_transpose_last_two_dims", + perm=[0, 1, 3, 2], + ), + ] + v_nodes = [ + helper.make_node( + "MatMul", + ["layernorm_add_output_to_matmul", "v_weight"], + ["v_matmul_output"], + "v_path_matmul", + ), + helper.make_node("Add", ["v_bias", "v_matmul_output"], ["v_add_output"], "v_path_add"), + helper.make_node("Reshape", ["v_add_output", "v_attn_heads_output"], ["v_4d_bsnh"], "v_reshape_to_4d"), + helper.make_node("Transpose", ["v_4d_bsnh"], ["v_4d_bnsh"], "v_transpose_to_bnsh", perm=[0, 2, 1, 3]), + helper.make_node( + "Concat", + ["gather_past_v_output", "v_4d_bnsh"], + ["oup_cache_v"], + "concat_past_v_and_curr_v", + axis=2, + ), + ] + pos_embed = [ + helper.make_node("Reshape", ["q_div_output", "position_embed_output"], ["reshape_pos_emb"], "r_pos_emb"), + helper.make_node( + "Transpose", ["reshape_pos_emb"], ["transpose_reshape_pos_emb"], "p_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "MatMul", + ["transpose_reshape_pos_emb", "transpose_reshape_pos_emb"], + ["pos_matmul"], + "pos_embed_matmul", + ), + helper.make_node( + "Transpose", ["pos_matmul"], ["transpose_pos_matmul"], "p_matmul_transpose", perm=[1, 0, 2] + ), + helper.make_node( + "Reshape", + ["transpose_pos_matmul", "position_embed_output"], + ["reshape_position_emb"], + "final_reshape_pos_emb", + ), + ] + nodes.extend(q_nodes) + nodes.extend(k_nodes) + nodes.extend(v_nodes) + nodes.extend(pos_embed) + + # Create nodes used with qkv concats, reshapes, and transposes + nodes.extend( + [ + helper.make_node("Shape", ["layernorm_add_output_to_matmul"], ["shape_output"], "shape", start=0), + helper.make_node("Gather", ["shape_output", "idx_0"], ["gather_0_output"], "gather_0", axis=0), + helper.make_node( + "Mul", + ["gather_0_output", "num_heads_int"], + ["mul_attn_heads_output"], + "mul_num_heads", + ), + helper.make_node( + "Unsqueeze", + ["mul_attn_heads_output", "unsqueeze_axes_input"], + ["unsqueeze_position_embed"], + "unsqueeze_position_embed", + ), + helper.make_node( + "Concat", + ["unsqueeze_position_embed", "neg_one", "head_size"], + ["position_embed_output"], + "position_embed_concat_output", + axis=0, + ), + 
helper.make_node( + "Unsqueeze", + ["gather_0_output", "unsqueeze_axes_input"], + ["unsqueeze_attn_heads_output"], + "unsqueeze_num_heads", + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["q_attn_heads_output"], + "q_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["k_attn_heads_output"], + "k_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size", "q_bsnh_reshape"], + ["v_attn_heads_output"], + "v_num_heads", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_attn_heads_output", "neg_one", "head_size"], + ["bsd_format"], + axis=0, + ), + helper.make_node( + "Constant", + inputs=[], + outputs=["q_bsnh_reshape"], + value=numpy_helper.from_array( + np.array([0, 0, num_heads, head_size], dtype="int64"), name="const_tensor" + ), + ), + ] + ) + + nodes.extend( + [ + helper.make_node("Gather", ["inp_cache_v", "idx_0"], ["gather_past_v_output"], "gather_past_v", axis=0), + helper.make_node("Gather", ["inp_cache_k", "idx_0"], ["gather_past_k_output"], "gather_past_k", axis=0), + ] + ) + + # Compute Q x K' + nodes.extend( + [ + helper.make_node( + "MatMul", + [ + "q_div_output", + "k_output_transpose", + ], + ["qk_output"], + "matmul_qk", + ) + ] + ) + + # Create nodes for computing softmax(Q x K') x V + nodes.extend( + [ + helper.make_node( + "Add", + [ + "qk_output", + "reshape_position_emb", + ], + ["add_qk_output"], + "add_qk", + ), + helper.make_node( + "Softmax", + ["add_qk_output"], + ["softmax_output"], + "softmax_qk", + axis=2, + ), + helper.make_node( + "MatMul", + ["softmax_output", "oup_cache_v"], + ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], + "matmul_qkv", + ), + helper.make_node( + "Transpose", + ["qkv_output_(num_heads*batch_size,seq_len,head_size)"], + ["qkv_bsnh"], + "transpose_bnsh_to_bsnh", + perm=[0, 2, 1, 3], + ), + helper.make_node("Reshape", ["qkv_bsnh", "bsd_format"], ["attn_output"], "qkv_bsd"), + ] + ) + + # Create final nodes to conclude attention + nodes.append( + helper.make_node( + "MatMul", + ["attn_output", "matmul_after_attn_initializer"], + ["matmul_after_attn_output"], + "matmul_after_attn", + ), + ) + if not fused: + next_sln_inputs = [ + "layernorm_add_output_to_skiplayernorm", + "add_after_attn_output", + "layernorm_weight", + "layernorm_bias", + ] + nodes.extend( + [ + helper.make_node( + "Add", + ["add_after_attn_initializer", "matmul_after_attn_output"], + ["add_after_attn_output"], + "add_after_attn", + ), + helper.make_node( + "SkipLayerNormalization", + next_sln_inputs, + ["output_0", "", "", "output_1"], + "next_skiplayernorm", + domain="com.microsoft", + epsilon=epsilon, + ), + ] + ) + else: + next_sln_inputs = [ + "matmul_after_attn_output", + "layernorm_add_output_to_skiplayernorm", + "layernorm_weight", + "layernorm_bias", + "add_after_attn_initializer", + ] + nodes.append( + helper.make_node( + "SkipLayerNormalization", + next_sln_inputs, + ["output_0", "", "", "output_1"], + "SkipLayerNorm_AddBias_0", + domain="com.microsoft", + epsilon=epsilon, + ) + ) + + # Create initializers + v_weight, v_weight_data = get_tensor_and_weight("v_weight", [hidden_size, hidden_size]) + v_bias, v_bias_data = get_tensor_and_weight("v_bias", [hidden_size]) + q_weight, q_weight_data = get_tensor_and_weight("q_weight", [hidden_size, hidden_size]) + q_bias, q_bias_data = get_tensor_and_weight("q_bias", [hidden_size]) + k_weight, 
k_weight_data = get_tensor_and_weight("k_weight", [hidden_size, hidden_size]) + k_bias, k_bias_data = get_tensor_and_weight("k_bias", [hidden_size]) + + qkv_bias = helper.make_tensor( + "Attention_0_qkv_bias", + TensorProto.FLOAT, + [3 * hidden_size], + q_bias_data + k_bias_data + v_bias_data, + ) + initializers = [ + float_tensor("layernorm_weight", [hidden_size]), + float_tensor("layernorm_bias", [hidden_size]), + float_tensor("matmul_after_attn_initializer", [hidden_size, hidden_size]), + float_tensor("add_after_attn_initializer", [hidden_size]), + ] + + # Add Q/K/V weight tensors as initializers + if fused: + initializers.extend([q_weight, k_weight, v_weight]) + initializers.extend([q_bias]) + initializers.append(qkv_bias) + initializers.extend( + [ + numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), + numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), + numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), + numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), + numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), + numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), + numpy_helper.from_array(np.array([0, 0, num_heads, head_size], dtype="int64"), name="q_bsnh_reshape"), + ] + ) + else: + initializers.extend([q_weight, k_weight, v_weight]) + + initializers.extend([q_bias, k_bias, v_bias]) + + initializers.extend( + [ + numpy_helper.from_array(np.array(num_heads, dtype="int64"), name="num_heads_int"), + numpy_helper.from_array(np.array([num_heads], dtype="int64"), name="num_heads"), + numpy_helper.from_array(np.array([head_size], dtype="int64"), name="head_size"), + numpy_helper.from_array(np.array([hidden_size], dtype="int64"), name="hidden_size"), + numpy_helper.from_array(np.array(1 / np.sqrt(head_size), dtype="float32"), name="q_scale"), + numpy_helper.from_array(np.array(0, dtype="int64"), name="idx_0"), + numpy_helper.from_array(np.array(1, dtype="int64"), name="idx_1"), + numpy_helper.from_array(np.array([-1], dtype="int64"), name="neg_one"), + numpy_helper.from_array(np.array([0], dtype="int64"), name="unsqueeze_axes_input"), + ] + ) + + # Construct graph + graph = helper.make_graph(nodes, "conformer_self_mha_graph", inputs, outputs, initializers, doc_string="conformer") + opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) + return helper.make_model(graph, opset_imports=(opsetid,)) + + +if __name__ == "__main__": + np.random.seed(2) + num_heads = 8 + hidden_size = 512 + + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size) + onnx.save(model, "conformer_self_mha.onnx") + + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, fused=True) + onnx.save(model, "./test_data/models/conformer/conformer_self_mha_fused.onnx") diff --git a/onnxruntime/test/python/transformers/test_conformer.py b/onnxruntime/test/python/transformers/test_conformer.py new file mode 100644 index 0000000000000..471ba9756bcf8 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_conformer.py @@ -0,0 +1,69 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import os +import unittest + +import onnx +from conformer_model_generator import create_conformer_attention +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +class TestFusion(unittest.TestCase): + def verify_fusion(self, optimized_model, expected_model_filename): + optimized_model.topological_sort(is_deterministic=True) + + expected_model_path = os.path.join( + os.path.dirname(__file__), "test_data", "models", "conformer", expected_model_filename + ) + print("Expected model path = ", expected_model_path) + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + nodes = optimized_model.model.graph.node + self.assertEqual(len(nodes), len(expected_model.model.graph.node)) + + for i in range(len(nodes)): + self.assertEqual(nodes[i], expected_model.model.graph.node[i]) + + for expected_initializer in expected_model.model.graph.initializer: + print("Expected initializer initial = ", expected_initializer.name) + self.assertTrue( + OnnxModel.has_same_value( + optimized_model.get_initializer(expected_initializer.name), expected_initializer + ) + ) + + def test_ct_mha_fusion(self): + num_heads = 8 + hidden_size = 512 + model = create_conformer_attention(num_heads=num_heads, hidden_size=hidden_size, add_before_layernorm=False) + dir = "." + model_path = os.path.join(dir, "conformer_self_mha.onnx") + onnx.save(model, model_path) + options = FusionOptions("conformer") + optimized_model = optimize_model( + model_path, + model_type="conformer", + num_heads=num_heads, + hidden_size=hidden_size, + optimization_options=options, + ) + os.remove(model_path) + self.verify_fusion(optimized_model, "conformer_self_mha_fused.onnx") + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx new file mode 100644 index 0000000000000..9d882751db265 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/conformer/conformer_self_mha_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 99f62ffdb9f53..8a839875de2a2 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -183,7 +183,9 @@ def create_multihead_attention_graph(config): return model.SerializeToString() -def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSNH, share_buffer=True): +def create_group_query_attention_graph_prompt( + config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 +): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length nodes = [ @@ -202,6 +204,7 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN "GroupQueryAttention_0", num_heads=config.num_heads, 
kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -297,6 +300,26 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN config.head_size, ], ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), ] graph = helper.make_graph( @@ -310,7 +333,9 @@ def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSN return model.SerializeToString() -def create_group_query_attention_graph_past(config, past_kv_format=Formats.BSNH, share_buffer=True): +def create_group_query_attention_graph_past( + config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 +): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length @@ -331,6 +356,7 @@ def create_group_query_attention_graph_past(config, past_kv_format=Formats.BSNH, "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -636,8 +662,12 @@ def mha_func(q, k, v, config): return output -def gqa_prompt_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): - onnx_model_str = create_group_query_attention_graph_prompt(config, past_kv_format, share_buffer) +def gqa_prompt_func( + q, k, v, config, new_k, new_v, seqlens_k=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True +): + onnx_model_str = create_group_query_attention_graph_prompt( + config, past_kv_format, share_buffer, local_window_size=window_size + ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None @@ -706,8 +736,12 @@ def gqa_prompt_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_forma return output, present_k, present_v -def gqa_past_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): - onnx_model_str = create_group_query_attention_graph_past(config, past_kv_format, share_buffer) +def gqa_past_func( + q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1 +): + onnx_model_str = create_group_query_attention_graph_past( + config, past_kv_format, share_buffer, local_window_size=window_size + ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() @@ -796,6 +830,28 @@ def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_paddi return col_idx > row_idx + sk - sq +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + 
query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + def attention_ref( q, k, @@ -805,6 +861,7 @@ def attention_ref( dropout_p=0.0, dropout_mask=None, causal=False, + window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, ): @@ -817,6 +874,8 @@ def attention_ref( key_padding_mask: (batch_size, seqlen_k) dropout_p: float dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16. reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) @@ -826,6 +885,8 @@ def attention_ref( output: (batch_size, seqlen_q, nheads, head_dim) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ + if causal: + window_size = (window_size[0], 0) dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() @@ -839,12 +900,24 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - if causal: - causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device) - scores.masked_fill_(causal_mask, float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) - if causal: # Some rows are completely masked out so we fill them with zero instead of NaN - attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) @@ -853,7 +926,6 @@ def attention_ref( output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) @@ -957,6 +1029,8 @@ def
parity_check_gqa_prompt( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1007,6 +1081,15 @@ def parity_check_gqa_prompt( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = k.clone() v_cache_ref = v.clone() @@ -1033,14 +1116,18 @@ def parity_check_gqa_prompt( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) + out, present_k, present_v = gqa_prompt_func( + q, k, v, config, new_k, new_v, cache_seqlens, left_window_size, past_format, True + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1052,6 +1139,10 @@ def parity_check_gqa_prompt( # Compare results print( "KV-buffer", + " causal:", + causal, + " local:", + local, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1080,6 +1171,8 @@ def parity_check_gqa_prompt( def parity_check_gqa_prompt_no_buff( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1112,6 +1205,15 @@ def parity_check_gqa_prompt_no_buff( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = new_k.clone() v_cache_ref = new_v.clone() @@ -1132,14 +1234,18 @@ def parity_check_gqa_prompt_no_buff( new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func(q, None, None, config, new_k, new_v, cache_seqlens, past_format, False) + out, present_k, present_v = gqa_prompt_func( + q, None, None, config, new_k, new_v, cache_seqlens, left_window_size, past_format, False + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ 
-1179,6 +1285,8 @@ def parity_check_gqa_prompt_no_buff( def parity_check_gqa_past( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1228,6 +1336,14 @@ def parity_check_gqa_past( dtype=torch.float16, requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) # Pytorch to compare k_cache_ref = k.clone() @@ -1253,14 +1369,18 @@ def parity_check_gqa_past( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) + out, present_k, present_v = gqa_past_func( + q, k, v, config, new_k, new_v, cache_seqlens, past_format, True, left_window_size + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1274,6 +1394,10 @@ def parity_check_gqa_past( "KV-buffer", "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", + " causal:", + causal, + " local:", + local, " B:", config.batch_size, " S:", @@ -1300,6 +1424,8 @@ def parity_check_gqa_past( def parity_check_gqa_past_no_buff( config, + causal=False, + local=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1351,6 +1477,15 @@ def parity_check_gqa_past_no_buff( requires_grad=False, ) + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + # Pytorch to compare k_cache_ref = k.clone() v_cache_ref = v.clone() @@ -1378,14 +1513,18 @@ def parity_check_gqa_past_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref, _ = attention_ref( + q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, False) + out, present_k, present_v = gqa_past_func( + q, k, v, config, new_k, new_v, cache_seqlens, past_format, False, window_size=left_window_size + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, 
config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1401,142 +1540,10 @@ def parity_check_gqa_past_no_buff( # Compare results print( "NO buff", - "past kv format:", - "BSNH" if past_format == Formats.BSNH else "BNSH", - " B:", - config.batch_size, - " S:", - config.sequence_length, - " kv S:", - config.kv_sequence_length, - " N:", - config.num_heads, - " kv N:", - config.kv_num_heads, - " h:", - config.head_size, - " Mean Error:", - numpy.mean(numpy.abs(out - out_ref)), - numpy.allclose( - out, - out_ref, - rtol=rtol, - atol=atol, - equal_nan=True, - ), - ) - - -def parity_check_gqa_past_no_buff_no_mask( - config, - past_format=Formats.BSNH, - rtol=1e-3, - atol=1e-3, -): - q = torch.randn( - config.batch_size, - config.sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.past_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.past_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - # Pytorch to compare - k_cache_ref = k.clone() - v_cache_ref = v.clone() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - k_cache_ref = torch.cat((k_cache_ref, new_k), 1) - v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - key_padding_mask = None - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, past_format, False) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - # Make sure past-present buffer updating correctly - if past_format == Formats.BSNH: - assert numpy.allclose( - present_k, - k_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - assert numpy.allclose( - present_v, - v_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - else: - assert numpy.allclose( - present_k, - k_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - assert numpy.allclose( - present_v, - v_cache_ref.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - ) - - 
# Compare results - print( - "Unbuffered", + " causal:", + causal, + " local:", + local, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1663,10 +1670,11 @@ def test_gqa_no_past(self): for sq, skv in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - parity_check_gqa_prompt(config, past_format=past_kv_format) - parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format) + for local in [False, True]: + for past_kv_format in [Formats.BNSH]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt(config, local=local, past_format=past_kv_format) + parity_check_gqa_prompt_no_buff(config, local=local, past_format=past_kv_format) def test_gqa_past(self): if not torch.cuda.is_available(): @@ -1725,24 +1733,25 @@ def test_gqa_past(self): for s, s2 in seqs: for n, n2 in num_h: for h in h_sizes: - for past_kv_format in [Formats.BNSH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for local in [False, True]: + for past_kv_format in [Formats.BNSH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) if __name__ == "__main__": unittest.main() - # test_gqa = TestGQA() - # test_gqa.test_gqa_past() diff --git a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py index b17ae5f69aff5..cf8128e0eebcf 100644 --- a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py +++ b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py @@ -261,14 +261,15 @@ def get_eps(self): eps = ["CPUExecutionProvider", "CUDAExecutionProvider"] return list(filter(lambda ep: ep in ort.get_available_providers(), eps)) - def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh): + def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh, transposed=False): eps = self.get_eps() for ep in eps: sess = ort.InferenceSession(onnx_graph, providers=[ep]) output_ort = sess.run(None, inputs_ort)[0] - output_ort = output_ort.reshape( - (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size) - ) + if not transposed: + output_ort = output_ort.reshape( + (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size) + ) # Compare outputs as BxSxNxH self.assertTrue(np.allclose(expected_output_bsnh, output_ort)) @@ -445,6 +446,44 @@ def test_hf_token_rotary_one_pos_id(self): # Compare outputs as BxSxNxH self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + # Bonus test: Prompt step, interleaved = false, pos ids shape = (1), transposed + def test_hf_prompt_rotary_one_pos_id_transposed(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = 
torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([0]) + onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bnsh.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare outputs as BxNxSxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True) + + # Bonus test: Token generation step, interleaved = false, pos ids shape = (1), transposed + def test_hf_token_rotary_one_pos_id_transposed(self): + x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxSxNxH + + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([2]) + onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bnsh.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Set transposed=True to compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/transform/gh_issue_18338.onnx b/onnxruntime/test/testdata/transform/gh_issue_18338.onnx new file mode 100644 index 0000000000000..afb499a347ec7 Binary files /dev/null and b/onnxruntime/test/testdata/transform/gh_issue_18338.onnx differ diff --git a/onnxruntime/test/testdata/transform/gh_issue_18338.py b/onnxruntime/test/testdata/transform/gh_issue_18338.py new file mode 100644 index 0000000000000..dc5446ac56c09 --- /dev/null +++ b/onnxruntime/test/testdata/transform/gh_issue_18338.py @@ -0,0 +1,859 @@ +import google.protobuf.text_format +import onnx +from numpy import array, float16 + +import onnxruntime as ort + +# Run n times +N = 1 + +onnx_model_text = """ +ir_version: 8 +producer_name: "pytorch" +producer_version: "2.2.0" +graph { + node { + output: "_val_1" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value_ints" + ints: -1 + type: INTS + } + doc_string: "" + } + node { + input: "input_0" + input: "_val_1" + output: "_val_2" + name: "Reshape_1" + op_type: "Reshape" + attribute { + name: "allowzero" + i: 0 + type: INT + } + doc_string: "" + } + node { + input: "_val_2" + output: "_val_3" + name: "_aten_linalg_vector_norm_no_dim_onnx_2" + op_type: "_aten_linalg_vector_norm_no_dim_onnx" + attribute { + name: "keepdim" + i: 0 + type: INT + } + attribute { + name: "ord" + f: 2.0 + type: FLOAT + } + doc_string: "" + domain: "pkg.onnxscript.torch_lib" + } + name: "main_graph" + input { + name: "input_0" + type { + tensor_type { + elem_type: 10 + shape { + } + } + } + } + output { + name: "_val_3" + type { + tensor_type { + elem_type: 10 + shape { + } + } + } + } + value_info { + name: "_val_1" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + value_info { + name: "_val_2" + type { + tensor_type { + elem_type: 10 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + domain: "pkg.onnxscript.torch_lib" + 
version: 1 +} +opset_import { + domain: "" + version: 18 +} +opset_import { + domain: "pkg.onnxscript.torch_lib.common" + version: 1 +} +functions { + name: "_aten_linalg_vector_norm_no_dim_onnx" + input: "self" + output: "result_29" + attribute: "ord" + attribute: "keepdim" + node { + input: "self" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "self_rank" + name: "n1" + op_type: "Size" + domain: "" + } + node { + output: "int64_0" + name: "n2" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + int64_data: 0 + name: "int64_0" + } + type: TENSOR + } + domain: "" + } + node { + input: "int64_0" + input: "self_rank" + output: "int64_0_cast" + name: "n3" + op_type: "CastLike" + domain: "" + } + node { + input: "self_rank" + input: "int64_0_cast" + output: "cond" + name: "n4" + op_type: "Equal" + domain: "" + } + node { + input: "cond" + output: "self_2" + name: "n5" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + output: "int64_0_1d" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: 0 + name: "int64_0_1d" + } + type: TENSOR + } + domain: "" + } + node { + input: "self" + input: "int64_0_1d" + output: "self_0" + name: "n1" + op_type: "Unsqueeze" + domain: "" + } + name: "thenGraph_4" + output { + name: "self_0" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "self" + output: "self_1" + name: "n0" + op_type: "Identity" + domain: "" + } + name: "elseGraph_4" + output { + name: "self_1" + type { + } + } + } + type: GRAPH + } + domain: "" + } + node { + input: "self_2" + output: "self_3" + name: "n6" + op_type: "Abs" + domain: "" + } + node { + output: "ord" + name: "n7" + op_type: "Constant" + attribute { + name: "value_float" + type: FLOAT + ref_attr_name: "ord" + } + domain: "" + } + node { + input: "ord" + output: "ord_4" + name: "n8" + op_type: "Cast" + attribute { + name: "to" + i: 1 + type: INT + } + domain: "" + } + node { + input: "ord_4" + output: "cond_5" + name: "n9" + op_type: "IsInf" + attribute { + name: "detect_negative" + i: 0 + type: INT + } + attribute { + name: "detect_positive" + i: 1 + type: INT + } + domain: "" + } + node { + input: "cond_5" + output: "result_24" + name: "n10" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result" + name: "n0" + op_type: "ReduceMax" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_9" + output { + name: "result" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "ord_4" + output: "cond_6" + name: "n0" + op_type: "IsInf" + attribute { + name: "detect_negative" + i: 1 + type: INT + } + attribute { + name: "detect_positive" + i: 0 + type: INT + } + domain: "" + } + node { + input: "cond_6" + output: "result_23" + name: "n1" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_7" + name: "n0" + op_type: "ReduceMin" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_11" + output { + name: "result_7" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 0.0 + name: "const" + } + type: TENSOR + } + domain: "" + } + node { + 
input: "const" + input: "ord_4" + output: "const_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_cast" + output: "cond_8" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_8" + output: "result_22" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "self_bool" + name: "n0" + op_type: "Cast" + attribute { + name: "to" + i: 9 + type: INT + } + domain: "" + } + node { + input: "self_bool" + input: "self_3" + output: "self_0_1" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "self_0_1" + output: "result_9" + name: "n2" + op_type: "ReduceSum" + attribute { + name: "keepdims" + i: 0 + type: INT + } + domain: "" + } + name: "thenGraph_13" + output { + name: "result_9" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const_10" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 1.0 + name: "const_10" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_10" + input: "ord_4" + output: "const_10_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_10_cast" + output: "cond_11" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_11" + output: "result_21" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_12" + name: "n0" + op_type: "ReduceL1" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_18" + output { + name: "result_12" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + output: "const_13" + name: "n0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 2.0 + name: "const_13" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_13" + input: "ord_4" + output: "const_13_cast" + name: "n1" + op_type: "CastLike" + domain: "" + } + node { + input: "ord_4" + input: "const_13_cast" + output: "cond_14" + name: "n2" + op_type: "Equal" + domain: "" + } + node { + input: "cond_14" + output: "result_20" + name: "n3" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "self_3" + output: "result_15" + name: "n0" + op_type: "ReduceL2" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + name: "thenGraph_20" + output { + name: "result_15" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "ord_4" + input: "self_3" + output: "ord_float" + name: "n0" + op_type: "CastLike" + domain: "" + } + node { + input: "self_3" + input: "ord_float" + output: "self_pow" + name: "n1" + op_type: "Pow" + domain: "" + } + node { + input: "self_pow" + output: "tmp_16" + name: "n2" + op_type: "ReduceSum" + attribute { + name: "keepdims" + type: INT + ref_attr_name: "keepdim" + } + domain: "" + } + node { + output: "const_17" + name: "n3" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + float_data: 1.0 + name: "const_17" + } + type: TENSOR + } + domain: "" + } + node { + input: "const_17" + input: "ord_float" + output: "const_17_cast" + name: "n4" + op_type: "CastLike" + domain: "" + } + node { + input: "const_17_cast" + input: "ord_float" + output: "tmp_18" + name: "n5" + op_type: "Div" + domain: "" + } + node { + input: "tmp_16" + input: "tmp_18" + output: 
"result_19" + name: "n6" + op_type: "Pow" + domain: "" + } + name: "elseGraph_20" + output { + name: "result_19" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_18" + output { + name: "result_20" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_13" + output { + name: "result_21" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_11" + output { + name: "result_22" + type { + } + } + } + type: GRAPH + } + domain: "" + } + name: "elseGraph_9" + output { + name: "result_23" + type { + } + } + } + type: GRAPH + } + domain: "" + } + node { + output: "int64_0_25" + name: "n11" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + int64_data: 0 + name: "int64_0_25" + } + type: TENSOR + } + domain: "" + } + node { + input: "int64_0_25" + input: "self_rank" + output: "int64_0_25_cast" + name: "n12" + op_type: "CastLike" + domain: "" + } + node { + input: "self_rank" + input: "int64_0_25_cast" + output: "cond_26" + name: "n13" + op_type: "Equal" + domain: "" + } + node { + input: "cond_26" + output: "result_29" + name: "n14" + op_type: "If" + attribute { + name: "then_branch" + g { + node { + input: "result_24" + output: "result_27" + name: "n0" + op_type: "Squeeze" + domain: "" + } + name: "thenGraph_27" + output { + name: "result_27" + type { + } + } + } + type: GRAPH + } + attribute { + name: "else_branch" + g { + node { + input: "result_24" + output: "result_28" + name: "n0" + op_type: "Identity" + domain: "" + } + name: "elseGraph_27" + output { + name: "result_28" + type { + } + } + } + type: GRAPH + } + domain: "" + } + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib" +} +functions { + name: "Rank" + input: "input" + output: "return_val" + node { + input: "input" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "return_val" + name: "n1" + op_type: "Size" + domain: "" + } + doc_string: "Take the rank of the input tensor." + opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib.common" +} +functions { + name: "IsScalar" + input: "input" + output: "return_val" + node { + input: "input" + output: "tmp" + name: "n0" + op_type: "Shape" + domain: "" + } + node { + input: "tmp" + output: "tmp_0" + name: "n1" + op_type: "Size" + domain: "" + } + node { + output: "tmp_1" + name: "n2" + op_type: "Constant" + attribute { + name: "value_int" + i: 0 + type: INT + } + domain: "" + } + node { + input: "tmp_0" + input: "tmp_1" + output: "return_val" + name: "n3" + op_type: "Equal" + domain: "" + } + doc_string: "Return whether the input has rank 0, or is a scalar." 
+ opset_import { + domain: "" + version: 18 + } + domain: "pkg.onnxscript.torch_lib.common" +} + +""" + +ort_inputs = {"input_0": array(0.8965, dtype=float16)} + +# Set up the inference session +session_options = ort.SessionOptions() +session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL +onnx_model = onnx.ModelProto() +google.protobuf.text_format.Parse(onnx_model_text, onnx_model) + +# Uncomment this line to save the model to a file for examination +# onnx.save_model(onnx_model, "test_output_match_opinfo__linalg_vector_norm_cpu_float16.onnx") + +onnx.checker.check_model(onnx_model) +session = ort.InferenceSession(onnx_model.SerializeToString(), session_options, providers=("CPUExecutionProvider",)) + +# Run the model +for _ in range(N): + ort_outputs = session.run(None, ort_inputs) diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc new file mode 100644 index 0000000000000..0412000e04e1b --- /dev/null +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/optimizer/initializer.h" +#include "orttraining/core/optimizer/conv1d_replacement.h" +#include "core/graph/graph_utils.h" + +/* + In LoRA code, it will use conv1d to do projection for qkv, + while the conv1d calculation is mathematically equivalent to MatMul, and MatMul is much faster than conv1d in GPU. + The graph transformation is doing the following graph substitution: + 1. The input graph is: + conv_input conv_weight + \ / + \ / + conv1d + + 2. The output graph is as follows, + the number of MatMul is equal to attribute "group" of conv1d + conv_input conv1d.group conv_weight conv1d.group + \ / \ / + \ / Squeeze / + \ / \ / + Split Split + / / ... \ / / ... \ + / / ... \ / / ... \ + / / ... \ / / ... \ + input0 input1 ... inputN weight0 weight1 ... weightN + \ \ \ / / / + \ \ \ / / / + \ \ \ / / / + \ \ X / / + \ \ / \ / / + \ \ / X / + \ X / \ / + \ / \ / \ / + MatMul MatMul ... MatMul + \ | ... 
/ + \ | / + \ | / +*/ +namespace onnxruntime { +bool NodeCanBeReplacedByMatmul(const Node& node) { + // If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2, + // then it can be replaced by MatMul + // Kernel_shape is 1 means it is conv1d + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) { + return false; + } + const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); + const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); + const auto* stride = graph_utils::GetNodeAttribute(node, "strides"); + const auto* group = graph_utils::GetNodeAttribute(node, "group"); + if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) { + return false; + } + if ((dilations->ints_size() && dilations->ints(0) != 1) || + (kernel_shape->ints_size() && kernel_shape->ints(0) != 1) || + (stride->ints_size() && stride->ints(0) != 1) || + group->i() >= 3) { + return false; + } + + return true; +} + +void Conv1dToMatmul(Graph& graph, Node& conv) { + // Shape of conv1d input: [batch_size, in_channels, in_length] + // Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1 + // We need to split the input into "group", and squeeze&split the weight, and then do MatMul + const std::string node_description("Conv1dReplacement"); + auto execution_provider_type = conv.GetExecutionProviderType(); + // 1. Split conv input + auto group_attr = graph_utils::GetNodeAttribute(conv, "group"); + int64_t group_num = 1; // default group is 1 from ONNX schema + if (group_attr != nullptr) { + group_num = group_attr->i(); + } + auto conv1d_input = conv.MutableInputDefs()[0]; + std::vector conv1d_input_splitted_outputs; + for (int i = 0; i < group_num; i++) { + conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("input_split_output"), nullptr)); + } + auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input}, + {conv1d_input_splitted_outputs}); + input_split.SetExecutionProviderType(execution_provider_type); + input_split.AddAttribute("axis", int64_t(1)); + auto onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + if (onnx_opset_version >= 18) { + input_split.AddAttribute("num_outputs", group_num); + } + // 2. 
Squeeze conv weight + auto conv1d_weight = conv.MutableInputDefs()[1]; + auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr); + auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze", + node_description, {conv1d_weight}, {weight_squeeze_output}); + if (onnx_opset_version > 12) { + // After onnx version 12, squeeze node has axes as input instead of attribute + ONNX_NAMESPACE::TensorProto initializer_proto; + initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer")); + initializer_proto.add_dims(static_cast(1)); + initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + InlinedVector initializer_proto_value{2}; + initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); + auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); + // Squeeze node doesn't have opschema here, so we need to set input args count manually + weight_squeeze.MutableInputArgsCount().resize(2); + graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); + } else { + weight_squeeze.AddAttribute("axes", std::vector{2}); + } + weight_squeeze.SetExecutionProviderType(execution_provider_type); + // 3. Split conv weight + std::vector conv1d_weight_splitted_outputs; + for (int i = 0; i < group_num; i++) { + conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("weight_split_output"), nullptr)); + } + auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, + {weight_squeeze_output}, {conv1d_weight_splitted_outputs}); + weight_split.AddAttribute("axis", int64_t(0)); + weight_split.SetExecutionProviderType(execution_provider_type); + if (onnx_opset_version >= 18) { + weight_split.AddAttribute("num_outputs", group_num); + } + // 4. Do MatMul + std::vector matmul_outputs; + for (int i = 0; i < group_num; i++) { + auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr); + matmul_outputs.push_back(matmul_output); + auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description, + {conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]}, + {matmul_output}); + matmul.SetExecutionProviderType(execution_provider_type); + } + // 5. Concat matmul outputs + auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description, + matmul_outputs, {}); + concat_node.SetExecutionProviderType(execution_provider_type); + concat_node.AddAttribute("axis", int64_t(1)); + // 6. 
Clean up - delete the original "conv" node; its output is replaced by concat_node
+  graph_utils::FinalizeNodeFusion(graph, concat_node, conv);
+}
+
+Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  GraphViewer graph_viewer(graph);
+  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+  for (auto node_index : node_topology_list) {
+    auto* node_ptr = graph.GetNode(node_index);
+    if (!node_ptr)
+      continue;  // node was removed
+    auto& node = *node_ptr;
+    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+    if (NodeCanBeReplacedByMatmul(node)) {
+      LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name();
+      Conv1dToMatmul(graph, node);
+      modified = true;
+    }
+  }
+  return Status::OK();
+}
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.h b/orttraining/orttraining/core/optimizer/conv1d_replacement.h
new file mode 100644
index 0000000000000..740f13c76fd6f
--- /dev/null
+++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.h
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+class Conv1dReplacement : public GraphTransformer {
+ public:
+  Conv1dReplacement(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
+      : GraphTransformer("Conv1dReplacement", compatible_execution_providers) {}
+
+  Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+};
+
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index 57d76577f1ba7..6193a1d10c095 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -72,6 +72,7 @@
 #ifdef ENABLE_TRAINING_TORCH_INTEROP
 #include "orttraining/core/optimizer/pythonop_rewriter.h"
 #endif
+#include "orttraining/core/optimizer/conv1d_replacement.h"
 
 namespace onnxruntime {
 namespace training {
@@ -194,6 +195,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
       // Once we have a CPU kernel for PadAndUnflatten, we can remove the guard.
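For reference, the equivalence that Conv1dReplacement relies on can be checked numerically with a small PyTorch sketch (illustrative only, not part of this diff; the tensor sizes below are arbitrary):

import torch

batch, in_ch, out_ch, length, groups = 2, 8, 12, 5, 2
conv = torch.nn.Conv1d(in_ch, out_ch, kernel_size=1, groups=groups, bias=False)
x = torch.randn(batch, in_ch, length)
reference = conv(x)  # [batch, out_ch, length]

# Mirror the rewritten graph: Squeeze the kernel dim, Split input/weight per group,
# MatMul each pair, then Concat along the channel axis (axis 1).
w = conv.weight.squeeze(-1)          # [out_ch, in_ch / groups]
x_parts = x.chunk(groups, dim=1)     # Split conv input along channels
w_parts = w.chunk(groups, dim=0)     # Split squeezed weight along output channels
rewritten = torch.cat([wi @ xi for wi, xi in zip(w_parts, x_parts)], dim=1)

assert torch.allclose(reference, rewritten, atol=1e-5)

This mirrors the ASCII diagram in conv1d_replacement.cc: one MatMul per group, concatenated on axis 1, which is why the transformer only fires for the small group counts (1 or 2) accepted by NodeCanBeReplacedByMatmul.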
transformers.emplace_back(std::make_unique(compatible_eps, config.sparse_embedding_input_names)); + transformers.emplace_back(std::make_unique(compatible_eps)); #endif } diff --git a/orttraining/orttraining/python/checkpointing_utils.py b/orttraining/orttraining/python/checkpointing_utils.py deleted file mode 100644 index 460b9982297d1..0000000000000 --- a/orttraining/orttraining/python/checkpointing_utils.py +++ /dev/null @@ -1,127 +0,0 @@ -import os - -import torch - - -def list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"): - ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)] - ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)] - ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names] - - assert len(ckpt_file_names) > 0, 'No checkpoint files found with prefix "{}" in directory {}.'.format( - checkpoint_prefix, checkpoint_dir - ) - return ckpt_file_names - - -def get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None): - SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806 - MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806 - - if is_partitioned: - filename = MULTIPLE_CHECKPOINT_FILENAME.format( - prefix=prefix, world_rank=world_rank, world_size=(world_size - 1) - ) - else: - filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix) - - return filename - - -def _split_state_dict(state_dict): - optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] - split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} - for k, v in state_dict.items(): - mode = "fp32_param" - for optim_key in optimizer_keys: - if k.startswith(optim_key): - mode = "optimizer" - break - if k.endswith("_fp16"): - mode = "fp16_param" - split_sd[mode][k] = v - return split_sd - - -class CombineZeroCheckpoint: - def __init__(self, checkpoint_files, clean_state_dict=None): - assert len(checkpoint_files) > 0, "No checkpoint files passed" - self.checkpoint_files = checkpoint_files - self.clean_state_dict = clean_state_dict - self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1 - assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files" - self.weight_shape_map = dict() - self.sharded_params = set() - - def _split_name(self, name: str): - name_split = name.split("_view_") - view_num = None - if len(name_split) > 1: - view_num = int(name_split[1]) - optimizer_key = "" - mp_suffix = "" - if name_split[0].startswith("Moment_1"): - optimizer_key = "Moment_1_" - elif name_split[0].startswith("Moment_2"): - optimizer_key = "Moment_2_" - elif name_split[0].startswith("Update_Count"): - optimizer_key = "Update_Count_" - elif name_split[0].endswith("_fp16"): - mp_suffix = "_fp16" - param_name = name_split[0] - if optimizer_key: - param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split("_fp16")[0] - return param_name, optimizer_key, view_num, mp_suffix - - def _update_weight_statistics(self, name, value): - if name not in self.weight_shape_map: - self.weight_shape_map[name] = value.size() # original shape of tensor - - def _reshape_tensor(self, key): - value = self.aggregate_state_dict[key] - weight_name, _, _, _ = self._split_name(key) - set_size = self.weight_shape_map[weight_name] - self.aggregate_state_dict[key] = value.reshape(set_size) - - def _aggregate(self, param_dict): - for k, v in param_dict.items(): - weight_name, optimizer_key, 
view_num, mp_suffix = self._split_name(k) - if view_num is not None: - # parameter is sharded - param_name = optimizer_key + weight_name + mp_suffix - - if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]: - self.sharded_params.add(param_name) - # Found a previous shard of the param, concatenate shards ordered by ranks - self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v)) - else: - self.aggregate_state_dict[param_name] = v - else: - if k in self.aggregate_state_dict: - assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value" - else: - self.aggregate_state_dict[k] = v - self._update_weight_statistics(weight_name, v) - - def aggregate_checkpoints(self): - checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0] - self.aggregate_state_dict = dict() - - for i in range(self.world_size): - checkpoint_name = get_checkpoint_name(checkpoint_prefix, True, i, self.world_size) - rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu")) - if "model" in rank_state_dict: - rank_state_dict = rank_state_dict["model"] - - if self.clean_state_dict: - rank_state_dict = self.clean_state_dict(rank_state_dict) - - rank_state_dict = _split_state_dict(rank_state_dict) - self._aggregate(rank_state_dict["fp16_param"]) - self._aggregate(rank_state_dict["fp32_param"]) - self._aggregate(rank_state_dict["optimizer"]) - - for k in self.sharded_params: - self._reshape_tensor(k) - return self.aggregate_state_dict diff --git a/orttraining/orttraining/python/deprecated/__init__.py b/orttraining/orttraining/python/deprecated/__init__.py deleted file mode 100644 index 6e02db707bc47..0000000000000 --- a/orttraining/orttraining/python/deprecated/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -from onnxruntime.capi._pybind_state import TrainingParameters # noqa: F401 -from onnxruntime.capi.training.training_session import TrainingSession # noqa: F401 diff --git a/orttraining/orttraining/python/deprecated/training_session.py b/orttraining/orttraining/python/deprecated/training_session.py deleted file mode 100644 index a6900578e174b..0000000000000 --- a/orttraining/orttraining/python/deprecated/training_session.py +++ /dev/null @@ -1,68 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import os # noqa: F401 -import sys # noqa: F401 - -from onnxruntime.capi import _pybind_state as C -from onnxruntime.capi.onnxruntime_inference_collection import IOBinding # noqa: F401 -from onnxruntime.capi.onnxruntime_inference_collection import ( - InferenceSession, - Session, - check_and_normalize_provider_args, -) - - -class TrainingSession(InferenceSession): - def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None): - Session.__init__(self) - - if sess_options: - self._sess = C.TrainingSession(sess_options) - else: - self._sess = C.TrainingSession() - - # providers needs to be passed explicitly as of ORT 1.10 - # retain the pre-1.10 behavior by setting to the available providers. 
- if providers is None: - providers = C.get_available_providers() - - providers, provider_options = check_and_normalize_provider_args( - providers, provider_options, C.get_available_providers() - ) - - if isinstance(path_or_bytes, str): - config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options) - elif isinstance(path_or_bytes, bytes): - config_result = self._sess.read_bytes(path_or_bytes, parameters, providers, provider_options) - else: - raise TypeError(f"Unable to load from type '{type(path_or_bytes)}'") - - self.loss_scale_input_name = config_result.loss_scale_input_name - - self._inputs_meta = self._sess.inputs_meta - self._outputs_meta = self._sess.outputs_meta - - def __del__(self): - if self._sess: - self._sess.finalize() - - def get_state(self): - return self._sess.get_state() - - def get_model_state(self, include_mixed_precision_weights=False): - return self._sess.get_model_state(include_mixed_precision_weights) - - def get_optimizer_state(self): - return self._sess.get_optimizer_state() - - def get_partition_info_map(self): - return self._sess.get_partition_info_map() - - def load_state(self, dict, strict=False): - self._sess.load_state(dict, strict) - - def is_output_fp32_node(self, output_name): - return self._sess.is_output_fp32_node(output_name) diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py deleted file mode 100644 index 5286c087cfb64..0000000000000 --- a/orttraining/orttraining/python/ort_trainer.py +++ /dev/null @@ -1,1241 +0,0 @@ -import io -import os -import warnings - -import numpy as np -import onnx -import torch -import torch.nn -import torch.onnx -from onnx import helper, numpy_helper -from packaging.version import Version as LooseVersion - -import onnxruntime as ort -import onnxruntime.capi.pt_patch -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - -from ..training import postprocess -from .checkpointing_utils import CombineZeroCheckpoint, get_checkpoint_name, list_checkpoint_files - -DEFAULT_OPSET_VERSION = 14 - - -class IODescription: - def __init__(self, name, shape, dtype=None, num_classes=None): - self.name_ = name - self.shape_ = shape - self.dtype_ = dtype - self.num_classes_ = num_classes - - -class ModelDescription: - def __init__(self, inputs, outputs): - self.inputs_ = inputs - self.outputs_ = outputs - - -def resolve_symbolic_dimensions(inputs, input_descs, output_descs): - import copy - - output_descs_copy = copy.deepcopy(output_descs) - resolved_dims = {} - for input, input_desc in zip(inputs, input_descs): - for i, axis in enumerate(input_desc.shape_): - if isinstance(axis, str): - resolved_dims[axis] = input.size()[i] - - for output_desc in output_descs_copy: - for i, axis in enumerate(output_desc.shape_): - if isinstance(axis, str): - output_desc.shape_[i] = resolved_dims[axis] - - if any(isinstance(axis, str) for axis in output_desc.shape_ for output_desc in output_descs): - raise RuntimeError("Cannot run model with unknown output dimensions") - - return output_descs_copy - - -def generate_sample(desc, device=None): - # symbolic dimensions are described with strings. 
set symbolic dimensions to be 1 - size = [s if isinstance(s, (int)) else 1 for s in desc.shape_] - if desc.num_classes_: - return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device) - else: - return torch.randn(size, dtype=desc.dtype_).to(device) - - -def get_device_index(device): - if type(device) == str: # noqa: E721 - # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 - device = torch.device(device) - return 0 if device.index is None else device.index - - -def input_get_device_index(input): - if isinstance(input, (list, tuple)): - device_index = get_device_index(input[0].device) - else: - device_index = get_device_index(input.device) - - return device_index - - -def get_all_gradients_finite_arg_name(session): - all_fp16_or_fp32_gradients_finite_node_args = [x for x in session._outputs_meta if "all_gradients_finite" in x.name] - if len(all_fp16_or_fp32_gradients_finite_node_args) < 1: - raise RuntimeError( - "Failed to find a group NodeArg with name that matches 'all_gradients_finite'\ - from the training session." - ) - - return all_fp16_or_fp32_gradients_finite_node_args[0].name - - -def get_group_accumulated_gradients_output_node_arg_name(session): - # TODO: get the constant string via pybind. - # optimizer_graph_builder BuildGroupNode with fixed string: 'Group_Accumulated_Gradients' - accumulated_gradients_output_node_args = [ - x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name - ] - if len(accumulated_gradients_output_node_args) != 1: - raise RuntimeError( - "Failed to find a group NodeArg with name that matches 'Group_Accumulated_Gradients'\ - from the training session." - ) - - return accumulated_gradients_output_node_args[0].name - - -def ort_training_session_run_helper(session, iobinding, inputs, input_descs, output_descs, device, run_options=None): - for input, input_desc in zip(inputs, input_descs): - device_index = input_get_device_index(input) - iobinding.bind_input( - input_desc.name_, - input.device.type, - device_index, - dtype_torch_to_numpy(input.dtype), - list(input.size()), - input.data_ptr(), - ) - - output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs) - torch_outputs = {} - for output_desc in output_descs_resolved: - torch_tensor = torch.zeros( - output_desc.shape_, - device=device, - dtype=output_desc.eval_dtype_ if hasattr(output_desc, "eval_dtype_") else output_desc.dtype_, - ) - iobinding.bind_output( - output_desc.name_, - torch_tensor.device.type, - get_device_index(device), - dtype_torch_to_numpy(torch_tensor.dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - torch_outputs[output_desc.name_] = torch_tensor - - session.run_with_iobinding(iobinding, run_options) - return torch_outputs - - -def FuseSofmaxNLLToSoftmaxCE(onnx_model): # noqa: N802 - nll_count = 0 - while True: - nll_count = nll_count + 1 - nll_loss_node = None - nll_loss_node_index = 0 - for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss": - nll_loss_node = node - break - - if nll_loss_node is None: - break - - softmax_node = None - softmax_node_index = 0 - label_input_name = None - weight_input_name = None - for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "LogSoftmax": - # has to be connected to nll_loss - if len(nll_loss_node.input) > 2: - weight_input_name = nll_loss_node.input[2] - if node.output[0] == nll_loss_node.input[0]: - 
softmax_node = node - label_input_name = nll_loss_node.input[1] - break - elif node.output[0] == nll_loss_node.input[1]: - softmax_node = node - label_input_name = nll_loss_node.input[0] - break - else: - if softmax_node is not None: - break - - if softmax_node is None: - break - - # delete nll_loss and LogSoftmax nodes in order - if nll_loss_node_index < softmax_node_index: - del onnx_model.graph.node[softmax_node_index] - del onnx_model.graph.node[nll_loss_node_index] - else: - del onnx_model.graph.node[nll_loss_node_index] - del onnx_model.graph.node[softmax_node_index] - - probability_output_name = softmax_node.output[0] - node = onnx_model.graph.node.add() - inputs = ( - [softmax_node.input[0], label_input_name, weight_input_name] - if weight_input_name - else [softmax_node.input[0], label_input_name] - ) - node.CopyFrom( - onnx.helper.make_node( - "SparseSoftmaxCrossEntropy", - inputs, - [nll_loss_node.output[0], probability_output_name], - "nll_loss_node_" + str(nll_count), - ) - ) - - return onnx_model - - -def delete_input_with_name(input, name): - index = 0 - for i in input: - if i.name == name: - del input[index] - break - index = index + 1 - - -# reference: -# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html -# https://pytorch.org/docs/stable/tensors.html -# also must map to types accepted by: -# MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type) -def dtype_torch_to_numpy(torch_dtype): - if torch_dtype == torch.float64 or torch_dtype == torch.double: - return np.float64 - elif torch_dtype == torch.float32 or torch_dtype == torch.float: - return np.float32 - elif torch_dtype == torch.float16 or torch_dtype == torch.half: - return np.float16 - elif torch_dtype == torch.int64 or torch_dtype == torch.long: - return np.longlong - elif torch_dtype == torch.int32 or torch_dtype == torch.int: - return np.int32 - elif torch_dtype == torch.int16 or torch_dtype == torch.short: - return np.int16 - elif torch_dtype == torch.bool: - return bool - else: - raise Exception("Torch type to numpy type mapping unavailable for: " + str(torch_dtype)) - - -class model_loss_cls(torch.nn.Module): # noqa: N801 - def __init__(self, model, loss_fn): - super().__init__() - self.model_ = model - self.loss_fn_ = loss_fn - - def forward(self, *inputs): - # here we assume input can be unpacked into input and label - input, label = inputs[:-1], inputs[-1] - preds = self.model_(*input) - return self.loss_fn_(preds, label), preds - - -class WrapModel(torch.nn.Module): - def __init__(self, model, loss_fn, input_names): - super().__init__() - self.model_ = model - self.loss_fn_ = loss_fn - self.input_names_ = input_names - - def forward(self, *inputs): - import inspect - - # *inputs is given by torch trace. It is in the order of input_names. - # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names. 
- sig = inspect.signature(self.model_.forward) - list(sig.parameters.keys()) - - input_dict = {} - for key in sig.parameters: - if key in self.input_names_: - input_dict[key] = inputs[self.input_names_.index(key)] - - model_out = self.model_(**input_dict) - if self.loss_fn_ is None: - return model_out - - label = inputs[-1] - preds = model_out - return self.loss_fn_(preds, label), preds - - -def wrap_for_input_match(model, loss_fn, input_names): - import inspect - - sig = inspect.signature(model.forward) - ordered_list_keys = list(sig.parameters.keys()) - if loss_fn: - sig_loss = inspect.signature(loss_fn) - if len(sig_loss.parameters) != 2: - raise RuntimeError("loss function should take two arguments - predict and label.") - - # label shall be the second input to loss_fn. - ordered_list_keys = [*ordered_list_keys, list(sig_loss.parameters.keys())[1]] - - # name match is needed only when input_names are a subset - # of expected inputs (inputs to model and loss_fn combined). - if len(input_names) > len(ordered_list_keys): - # this is likely the case where input arguments are packed. - # TODO: to unpack the input argument. - return model_loss_cls(model, loss_fn) if loss_fn else model - elif len(input_names) == len(ordered_list_keys): - # in this case, we do not require name match. - return model_loss_cls(model, loss_fn) if loss_fn else model - - if not all(x in ordered_list_keys for x in input_names): - # model desc has name(s) not matching the model signature. We cannot do anything in this case. - # better to warning the user. - return model_loss_cls(model, loss_fn) if loss_fn else model - - # if input_names match ordered_list_keys, there is not need for wrapping - match = True - for i, input_name in enumerate(input_names): - if input_name != ordered_list_keys[i]: - match = False - break - - if match: - return model_loss_cls(model, loss_fn) if loss_fn else model - - model = WrapModel(model, loss_fn, input_names) - - return model - - -def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, opset_version=DEFAULT_OPSET_VERSION): - # example: {input0:{0:'batch'}, input1:{0:'batch'}} - dynamic_axes = {} - for input in model_desc.inputs_: - symbolic_axis = {} - for i, axis in enumerate(input.shape_): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[input.name_] = symbolic_axis - - for output in model_desc.outputs_: - symbolic_axis = {} - for i, axis in enumerate(output.shape_): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[output.name_] = symbolic_axis - - input_names = [input.name_ for input in model_desc.inputs_] - output_names = [output.name_ for output in model_desc.outputs_] - - if isinstance(inputs, torch.Tensor): - inputs = [inputs] - if isinstance(inputs, dict): - sample_inputs = [inputs[k.name_].to(device=device) for k in model_desc.inputs_] - elif isinstance(inputs, (list, tuple)): - sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(model_desc.inputs_)] - else: - raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.") - - # pytorch onnx exporter/trace does not try to match argument names. - # e.g. for models with optional inputs, it requires all inputs be present. - # this is a problem because the model graph depends on inputs provided. 
- model = wrap_for_input_match(model, loss_fn, input_names) - - model.eval() - with torch.no_grad(): - import copy - - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - try: - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(model) - except Exception: - model_copy = model - warnings.warn( - "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." - " Compute will continue, but unexpected results may occur!" - ) - - sample_outputs = model_copy(*sample_inputs_copy) - if isinstance(sample_outputs, torch.Tensor): - sample_outputs = [sample_outputs] - for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_): - output_desc.dtype_ = sample_output.dtype - model.train() - - f = io.BytesIO() - - # Other export options to use(this is for backward compatibility). - other_export_options = {} - other_export_options["training"] = True - - # This option was added after 1.4 release. - if LooseVersion(torch.__version__) > LooseVersion("1.4.0") and LooseVersion(torch.__version__) < LooseVersion( - "1.10.0" - ): - other_export_options["enable_onnx_checker"] = False - # This option was added after 1.6 release. - if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - other_export_options["training"] = torch.onnx.TrainingMode.TRAINING - - # Deepcopy inputs, since input values may change after model run. - import copy - - sample_inputs_copy = copy.deepcopy(sample_inputs) - - # Enable contrib ops export from PyTorch - from onnxruntime.tools import pytorch_export_contrib_ops - - pytorch_export_contrib_ops.register() - - torch.onnx._export( - model, - tuple(sample_inputs_copy), - f, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False, - **other_export_options, - ) - - onnx_model = onnx.load_model_from_string(f.getvalue()) - - # Remove 'model_.' prefix introduced by model wrapper for initializers. 
- if isinstance(model, (WrapModel, model_loss_cls)): - replace_name_dict = {} - for n in onnx_model.graph.initializer: - if n.name.startswith("model_."): - replace_name_dict[n.name] = n.name[len("model_.") :] - n.name = replace_name_dict[n.name] - for n in onnx_model.graph.node: - for i, name in enumerate(n.input): - if name in replace_name_dict: - n.input[i] = replace_name_dict[name] - - return onnx_model - - -def create_ort_training_session_with_optimizer( - model, - device, - training_optimizer_name, - lr_params_feed_name, - map_optimizer_attributes, - world_rank=-1, - world_size=1, - gradient_accumulation_steps=1, - bind_parameters=False, - use_mixed_precision=False, - allreduce_post_accumulation=False, - deepspeed_zero_stage=0, - enable_grad_norm_clip=True, - frozen_weights=[], # noqa: B006 - opset_version=DEFAULT_OPSET_VERSION, - use_deterministic_compute=False, - use_memory_efficient_gradient=False, - enable_adasum=False, - optimized_model_filepath="", -): - output_name = model.graph.output[0].name - ort_parameters = ort.TrainingParameters() - ort_parameters.loss_output_name = output_name - ort_parameters.use_mixed_precision = use_mixed_precision - ort_parameters.world_rank = world_rank - ort_parameters.world_size = world_size - ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps - ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation - ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage - ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip - ort_parameters.set_gradients_as_graph_outputs = False - ort_parameters.use_memory_efficient_gradient = use_memory_efficient_gradient - ort_parameters.enable_adasum = enable_adasum - output_types = {} - for output in model.graph.output: - output_types[output.name] = output.type.tensor_type - - # pybind does not allow to add directly to ort_parameters.weights_to_train. - # Have to work around by using a temporary weights_to_train. 
- torch_params = {} - optimizer_attributes_map = {} - optimizer_int_attributes_map = {} - - unused_frozen_weights = [n for n in frozen_weights if n not in [i.name for i in model.graph.initializer]] - if unused_frozen_weights: - raise RuntimeError(f"{unused_frozen_weights} in frozen_weights not found in model weights.") - - weights_to_train = set() - for initializer in model.graph.initializer: - if initializer.name in frozen_weights: - continue - weights_to_train.add(initializer.name) - if map_optimizer_attributes is not None: - attributes = map_optimizer_attributes(initializer.name) - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - for k, v in attributes.items(): - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - else: - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - - if bind_parameters: - for initializer in model.graph.initializer: - torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device)) - delete_input_with_name(model.graph.input, initializer.name) - model.graph.input.extend( - [helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)] - ) - torch_params[initializer.name] = torch_tensor - - del model.graph.initializer[:] - - ort_parameters.weights_to_train = weights_to_train - ort_parameters.training_optimizer_name = training_optimizer_name - ort_parameters.lr_params_feed_name = lr_params_feed_name - ort_parameters.optimizer_attributes_map = optimizer_attributes_map - ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map - - sessionOptions = ort.SessionOptions() # noqa: N806 - sessionOptions.use_deterministic_compute = use_deterministic_compute - if len(optimized_model_filepath) > 0: - sessionOptions.optimized_model_filepath = optimized_model_filepath - session = ort.TrainingSession(model.SerializeToString(), ort_parameters, sessionOptions) - train_io_binding = session.io_binding() - eval_io_binding = session.io_binding() - - if bind_parameters: - for param in torch_params: - torch_tensor = torch_params[param] - - train_io_binding.bind_input( - param, - torch_tensor.device.type, - get_device_index(torch_tensor.device), - dtype_torch_to_numpy(torch_params[param].dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - eval_io_binding.bind_input( - param, - torch_tensor.device.type, - get_device_index(torch_tensor.device), - dtype_torch_to_numpy(torch_params[param].dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - - return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types - - -def save_checkpoint( - model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", checkpoint_state_dict=None, include_optimizer_state=True -): - if checkpoint_state_dict is None: - checkpoint_state_dict = {"model": model.state_dict(include_optimizer_state)} - else: - checkpoint_state_dict.update({"model": model.state_dict(include_optimizer_state)}) - - assert os.path.exists(checkpoint_dir), f"ERROR: Checkpoint directory doesn't exist: {checkpoint_dir}" - - checkpoint_name = get_checkpoint_name( - checkpoint_prefix, model.deepspeed_zero_stage_, model.world_rank, model.world_size - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if 
os.path.exists(checkpoint_file): - warnings.warn(f"{checkpoint_file} already exists, overwriting.") - - torch.save(checkpoint_state_dict, checkpoint_file) - - -def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict): - checkpoint_name = get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if is_partitioned: - assert_msg = ( - f"Couldn't find checkpoint file {checkpoint_file}." - "Optimizer partitioning is enabled using ZeRO. Please make sure that the " - f"checkpoint file exists for rank {model.world_rank} of {model.world_size}." - ) - else: - assert_msg = f"Couldn't find checkpoint file {checkpoint_file}." - - assert os.path.exists(checkpoint_file), assert_msg - - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - - model.load_state_dict(checkpoint_state["model"], strict=strict) - del checkpoint_state["model"] - return checkpoint_state - - -def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict): - checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - - ckpt_agg = CombineZeroCheckpoint(checkpoint_files) - aggregate_state_dict = ckpt_agg.aggregate_checkpoints() - - model.load_state_dict(aggregate_state_dict, strict=strict) - - # aggregate other keys in the state_dict. - # Values will be overwritten for matching keys among workers - all_checkpoint_states = {} - for checkpoint_file in checkpoint_files: - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - del checkpoint_state["model"] - all_checkpoint_states.update(checkpoint_state) - return all_checkpoint_states - - -def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): - checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - is_partitioned = False - if len(checkpoint_files) > 1: - warnings.warn( - f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." - "Attempting to load ZeRO checkpoint." - ) - is_partitioned = True - if (not model.deepspeed_zero_stage_) and is_partitioned: - return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict) - else: - return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) - - -class ORTTrainer: - def __init__( - self, - model, - loss_fn, - model_desc, - training_optimizer_name, - map_optimizer_attributes, - learning_rate_description, - device, - gradient_accumulation_steps=1, - world_rank=0, - world_size=1, - use_mixed_precision=False, - allreduce_post_accumulation=False, - global_step=0, - get_lr_this_step=None, - loss_scaler=None, - deepspeed_zero_stage=0, - enable_grad_norm_clip=True, - frozen_weights=[], # noqa: B006 - _opset_version=DEFAULT_OPSET_VERSION, - _enable_internal_postprocess=True, - _extra_postprocess=None, - _use_deterministic_compute=False, - use_memory_efficient_gradient=False, - run_symbolic_shape_infer=False, - enable_adasum=False, - optimized_model_filepath="", - ): - super().__init__() - """ - Initialize ORTTrainer. - - Args: - - model: one of - - a PyTorch model (class that inherits from torch.nn.Module) - - a combined PyTorch model and loss function. - Inputs to this combined PyTorch model are a concatenation of the - model's input and the loss function's label input. - Outputs are a concatenation of the loss function's output and the - model's output. - - a combined ONNX model and loss function. 
- loss_fn: one of - - a PyTorch loss function if 'model' is a PyTorch model. A loss - function takes two inputs (prediction, label) and outputs a loss - tensor. - - None if model is already combined with a loss function. - model_desc: Specify input/output shapes, types, and names. - Must be consistent with the training model. - training_optimizer_name: one of - - 'SGDOptimizer' - - 'AdamOptimizer' - - 'LambOptimizer' - map_optimizer_attributes: for optimizers with weight-dependent - parameters. A callable that maps weight name to a set of optimization - parameters. - Defaults to None. - learning_rate_description: the name, shape and type of the learning - rate in form of IODescription(Learning_Rate_Name, [1,], torch.float32). - Because learning_rate is an input to the training model, - Learning_Rate_Name must be specified so that there is no name conflict - within the model. - device: device to store tensors (e.g. 'cpu', 'cuda', 'cuda:'). - gradient_accumulation_steps: number of training steps to accumulate - gradients before averaging and applying them. - Defaults to 1. - world_rank: rank id used for distributed training. - Defaults to 0. - world_size: number of ranks participating in distributed training. - Defaults to 1. - use_mixed_precision: flag to enable mixed precision (aka fp16). - Defaults to False. - allreduce_post_accumulation: controls whether overlaping gradient - computation is applied with allreduce. - Defaults to False. - global_step: training step that is used as input to 'get_lr_this_step'. - Defaults to 0. - get_lr_this_step: functor used as learning rate scheduler. - It uses 'global_step' as input. - Defaults to None. - loss_scaler: updates loss scale automatically when 'use_mixed_precision' - is specified. - Defaults to None. - deepspeed_zero_stage: controls whether to partition state using the DeepSpeed ZeRO technique. Stages 0 and 1 are supported. - Defaults to 0 (disabled). - enable_grad_norm_clip: enables gradient norm clipping. - Defaults to True. - frozen_weights: list of model parameters to be frozen (not trained). - Defaults to []. - _enable_internal_postprocess: whether to run or not the internal postprocesses. - Defaults to True - _extra_postprocess: a callable to postprocess the ONNX model that is converted from PyTorch. - Defaults to None - use_memory_efficient_gradient: use memory aware gradient builder. - Defaults to False - run_symbolic_shape_infer: run symbolic shape inference - Defaults to False - optimized_model_filepath: path to output the optimized training graph. - Defaults to "" (no output). - """ - warnings.warn( - "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", - FutureWarning, - ) - warnings.warn( - "DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it" - ) - self.is_train = True - - self.torch_model_ = None - self.onnx_model_ = None - self._enable_internal_postprocess = _enable_internal_postprocess - self._extra_postprocess = _extra_postprocess - - if isinstance(model, torch.nn.Module): - self.torch_model_ = model - self.loss_fn_ = loss_fn - self._torch_state_dict_keys = list(model.state_dict().keys()) - else: - self._torch_state_dict_keys = [] - self.onnx_model_ = model - if loss_fn is not None: - warnings.warn("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.") - # TODO: accept loss_fn as an onnx model. 
build self.onnx_model_ with model and loss_fn - self.loss_fn_ = None - - if self._enable_internal_postprocess: - postprocess.run_postprocess(self.onnx_model_) - - if self._extra_postprocess: - self._extra_postprocess(self.onnx_model_) - - self.model_desc_ = model_desc - self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description] - - self.world_rank = world_rank - self.world_size = world_size - self.use_mixed_precision = use_mixed_precision - - self.session = None - self.device_ = device - self.gradient_accumulation_steps = gradient_accumulation_steps - # we use self.current_step to count calls to train_step. It is used for gradient accumulation. - # gradients are being accumulated when self.current_step is not divisible by gradient_accumulation_steps. - # gradients are updated when self.current_step is divisible by gradient_accumulation_steps. - self.current_step = 0 - - # we use self.global_step_ to count optimizations being performed. - # it is used to calculate learning rate if self.get_lr_this_step_ is provided. - self.global_step_ = global_step - self.get_lr_this_step_ = get_lr_this_step - self.loss_scaler_ = loss_scaler - - if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None: - warnings.warn("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.") - self.training_optimizer_name_ = training_optimizer_name - self.learning_rate_description_ = learning_rate_description - self.map_optimizer_attributes_ = map_optimizer_attributes - self.allreduce_post_accumulation_ = allreduce_post_accumulation - self.deepspeed_zero_stage_ = deepspeed_zero_stage - self.enable_grad_norm_clip_ = enable_grad_norm_clip - self.frozen_weights_ = frozen_weights - self.opset_version_ = _opset_version - self.state_dict_ = None - self._use_deterministic_compute = _use_deterministic_compute - self.use_memory_efficient_gradient = use_memory_efficient_gradient - self.run_symbolic_shape_infer = run_symbolic_shape_infer - self.enable_adasum = enable_adasum - self.optimized_model_filepath = optimized_model_filepath - - # use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs. - # see prepare_input_and_fetches for more details. - self.loss_scale_input_name = "default_loss_scale_input_name" - - self._init_session() - - def _init_session(self): - if self.onnx_model_ is None: - return - - self._verify_fully_optimized_model(self.onnx_model_) - - if self.run_symbolic_shape_infer: - self.onnx_model_ = SymbolicShapeInference.infer_shapes( - self.onnx_model_, auto_merge=True, guess_output_rank=True - ) - - # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. 
- # for example, load_state_dict will be called before returing the function, and it calls _init_session again - del self.session - ( - self.session, - self.train_io_binding, - self.eval_io_binding, - self.output_name, - _, - self.output_types, - ) = create_ort_training_session_with_optimizer( - self.onnx_model_, - self.device_, - self.training_optimizer_name_, - self.learning_rate_description_.name_, - self.map_optimizer_attributes_, - self.world_rank, - self.world_size, - self.gradient_accumulation_steps, - bind_parameters=False, - use_mixed_precision=self.use_mixed_precision, - allreduce_post_accumulation=self.allreduce_post_accumulation_, - deepspeed_zero_stage=self.deepspeed_zero_stage_, - enable_grad_norm_clip=self.enable_grad_norm_clip_, - frozen_weights=self.frozen_weights_, - opset_version=self.opset_version_, - use_deterministic_compute=self._use_deterministic_compute, - use_memory_efficient_gradient=self.use_memory_efficient_gradient, - enable_adasum=self.enable_adasum, - optimized_model_filepath=self.optimized_model_filepath, - ) - - self.loss_scale_input_name = self.session.loss_scale_input_name - - if self.use_mixed_precision: - self.input_desc_with_lr_and_loss_scale = [ - *self.input_desc_with_lr, - IODescription(self.loss_scale_input_name, [], torch.float32), - ] - - # ORT backend has modified model output dtype from float32 to float16. - for o_desc in self.model_desc_.outputs_: - if ( - self.use_mixed_precision - and o_desc.dtype_ == torch.float32 - and not self.session.is_output_fp32_node(o_desc.name_) - ): - o_desc.eval_dtype_ = torch.float16 - else: - o_desc.eval_dtype_ = o_desc.dtype_ - - # gradient accumulation buffers are connected to a single node with a boolean, dimension 1 tensor output. - # add a matching output to drive gradient accumulation. - if self.gradient_accumulation_steps > 1: - self.output_desc_with_group_accumulated_gradients = [ - *self.model_desc_.outputs_, - IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool), - ] - - if self.use_mixed_precision: - # when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine - # if the gradient is usable. - self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [ - *self.model_desc_.outputs_, - IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool), - ] - - if self.state_dict_: - self.load_state_dict(self.state_dict_, self.strict_) - self.state_dict_ = None - - def _init_onnx_model(self, inputs): - if self.onnx_model_ is not None: - return - - if self.torch_model_ is not None: - # NOTE: pt model is moved to cpu to conserve gpu memory. - self.torch_model_.cpu() - # torch buffers created using 'register_buffer' are not meant to be trainable. 
- torch_buffers = list(dict(self.torch_model_.named_buffers()).keys()) - self.frozen_weights_ = self.frozen_weights_ + torch_buffers - self.onnx_model_ = convert_model_loss_fn_to_onnx( - self.torch_model_, - self.loss_fn_, - self.model_desc_, - torch.device("cpu"), - inputs, - opset_version=self.opset_version_, - ) - - if self._enable_internal_postprocess: - postprocess.run_postprocess(self.onnx_model_) - - if self._extra_postprocess: - self._extra_postprocess(self.onnx_model_) - - self._init_session() - - def train(self): - self.is_train = True - - def eval(self): - self.is_train = False - - def _update_onnx_model_initializers(self, state_tensors): - # replace the initializers with new value - new_weights = [] - replace_indices = [] - for i, w in enumerate(self.onnx_model_.graph.initializer): - if w.name in state_tensors: - new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name)) - replace_indices.append(i) - replace_indices.sort(reverse=True) - for w_i in replace_indices: - del self.onnx_model_.graph.initializer[w_i] - self.onnx_model_.graph.initializer.extend(new_weights) - - def state_dict(self, include_optimizer_state=True): - if not self.session: - warnings.warn( - "ONNXRuntime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling state_dict()." - ) - return {} - - # extract trained weights - session_state = self.session.get_state() - torch_state = {} - for name in session_state: - torch_state[name] = torch.from_numpy(session_state[name]) - - # extract untrained weights and buffer - for n in self.onnx_model_.graph.initializer: - if n.name not in torch_state: - torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n)) - - # Need to remove redundant initializers and name suffices to map back to original torch state names - if not include_optimizer_state and self._torch_state_dict_keys: - return {key: torch_state[key] for key in self._torch_state_dict_keys if key in torch_state} - return torch_state - - def load_state_dict(self, state_dict, strict=False): - # Note: It may happen ONNX model has not yet been initialized - # In this case we cache a reference to desired state and delay the restore until after initialization - # Unexpected behavior will result if the user changes the reference before initialization - if not self.session: - self.state_dict_ = state_dict - self.strict_ = strict - return - - # update onnx model from loaded state dict - cur_initializers_names = [n.name for n in self.onnx_model_.graph.initializer] - new_initializers = {} - - for name in state_dict: - if name in cur_initializers_names: - new_initializers[name] = state_dict[name].numpy() - elif strict: - raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.") - - self._update_onnx_model_initializers(new_initializers) - - # create new session based on updated onnx model - self.state_dict_ = None - self._init_session() - - # load training state - session_state = {name: state_dict[name].numpy() for name in state_dict} - self.session.load_state(session_state, strict) - - def save_as_onnx(self, path): - if not self.session: - warnings.warn( - "ONNXRuntime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling save_as_onnx()." 
- ) - return - state_tensors = self.session.get_state() - self._update_onnx_model_initializers(state_tensors) - - with open(path, "wb") as f: - f.write(self.onnx_model_.SerializeToString()) - - def _prepare_input_and_fetches( - self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs - ): - fetches = None - if type(args) == tuple and len(args) == 1 and type(args[0]) == list: # noqa: E721 - input = tuple(args[0]) - else: - input = args - - for input_desc in input_desc_with_: - if input_desc.name_ in kwargs: - input = (*input, kwargs[input_desc.name_]) - if internal_learning_rate is not None: - input = (*input, internal_learning_rate) - if internal_loss_scale is not None: - input = (*input, internal_loss_scale) - elif self.use_mixed_precision: - # loss_scale input name is needed to call train_step, for example: - # kwargs[model.loss_scale_input_name] = loss_scale - # outputs = model.train_step(*args, **kwargs) - # However, when first time train_step is called model.loss_scale_input_name is not set. - # To workaround this problem, we use the special name 'default_loss_scale_input_name' to indicate - # the loss_scale. - if "default_loss_scale_input_name" in kwargs: - input = (*input, kwargs["default_loss_scale_input_name"]) - - fetches = None - if "fetches" in kwargs: - fetches = kwargs["fetches"] - - return input, fetches - - def train_step(self, *args, **kwargs): - """ - inputs: model inputs, labels, learning rate, and, if in mixed_precision mode, loss_scale. - outputs: if fetches is not provided, outputs are loss and - (if in mixed mode and is finishing gradient accumulation) all_finite. - if fetches is provided, outputs contains these requested with fetches. - fetches: names of requested outputs - """ - - # inputs to the ONNX model includes inputs to the original PyTorch model - # plus learning rate and loss_scale if self.use_mixed_precision is True. - # 1. when there are internal learning_rate and loss_scale (in fp16 cases) generators, - # *args and **kwargs together contain ONLY and COMPLETE inputs to the PyTorch model. - # In this case, changes to the training script is minimized. - # 2. without internal learning rate and loss scale (in fp16 cases) generators, - # *args and **kwargs passed in from the training script shall contains - # inputs to the PyTorch model plus learning_rate and loss_scale. - # it optionally contains the fetches. - # localized arguments (*args) contains inputs to the ONNX model. 
- # named arguments can contain both inputs, learning_rate and loss_scale, and the fetches - - learning_rate, loss_scale = None, None - if self.get_lr_this_step_ is not None: - # $args, **kwargs contains inputs to the pytorch model - lr_this_step = self.get_lr_this_step_(self.global_step_) - learning_rate = torch.tensor([lr_this_step]) - if self.loss_scaler_ is not None and self.use_mixed_precision: - loss_scale = torch.tensor([self.loss_scaler_.loss_scale_]) - - if self.onnx_model_ is None: - sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) - self._init_onnx_model(sample_input) - - if self.use_mixed_precision: - input, fetches = self._prepare_input_and_fetches( - self.input_desc_with_lr_and_loss_scale, learning_rate, loss_scale, *args, **kwargs - ) - assert len(self.input_desc_with_lr_and_loss_scale) == len(input) - input_descs = self.input_desc_with_lr_and_loss_scale - else: - input, fetches = self._prepare_input_and_fetches( - self.input_desc_with_lr, learning_rate, loss_scale, *args, **kwargs - ) - assert len(self.input_desc_with_lr) == len(input) - input_descs = self.input_desc_with_lr - - self.current_step += 1 - - # handle gradient accumulation in fully optimized mode - run_options = None - has_if_all_finite = False - if fetches: - output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch] - elif self.current_step % self.gradient_accumulation_steps != 0: - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - output_desc = self.output_desc_with_group_accumulated_gradients - elif self.use_mixed_precision: - has_if_all_finite = True - output_desc = self.output_desc_with_all_fp_16_or_fp32_gradients_finite - else: - output_desc = self.model_desc_.outputs_ - - if not isinstance(input, (list, tuple)): - input = (input,) - - session_run_results = ort_training_session_run_helper( - self.session, self.train_io_binding, input, input_descs, output_desc, self.device_, run_options - ) - - if has_if_all_finite: - # After session run with all_fp32_gradients_finite, we need to clear the iobinding's output state. - # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce - # because all_fp32_gradients_finite is still in the feed. - self.train_io_binding.clear_binding_outputs() - all_finite = session_run_results[self.output_desc_with_all_fp_16_or_fp32_gradients_finite[-1].name_] - if self.loss_scaler_ is not None: - self.loss_scaler_.update_loss_scale(all_finite) - if all_finite: - # optimization has done, increase self.global_step_ - self.global_step_ = self.global_step_ + 1 - elif self.current_step % self.gradient_accumulation_steps == 0: - # optimization has done, increase self.global_step_ - self.global_step_ = self.global_step_ + 1 - - if fetches is not None: - results = [session_run_results[fetch] for fetch in fetches] - elif has_if_all_finite and self.loss_scaler_ is None: - # return descripted outputs plus the all_finite flag so that the training script can handle loss scaling. 
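The step bookkeeping used here is easy to misread, so a minimal, self-contained sketch of the same rule (illustrative values only, not the ORT API): current_step counts train_step calls, global_step counts applied optimizer updates, and in mixed precision an update is skipped when gradients overflow.

    # Hedged sketch of the accumulation / all-finite bookkeeping (illustrative only).
    gradient_accumulation_steps = 4
    current_step = global_step = 0
    for all_finite in [True] * 7 + [False]:          # stand-ins for per-call overflow checks
        current_step += 1
        at_boundary = current_step % gradient_accumulation_steps == 0
        if at_boundary and all_finite:
            global_step += 1                         # gradients averaged and applied
        # otherwise this call only accumulates gradients (or skips the update on overflow)
    print(current_step, global_step)                 # 8 calls -> 1 applied update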
- results = [ - session_run_results[output_desc.name_] - for output_desc in self.output_desc_with_all_fp_16_or_fp32_gradients_finite - ] - else: - results = [session_run_results[output_desc.name_] for output_desc in self.model_desc_.outputs_] - return results[0] if len(results) == 1 else results - - def __call__(self, *args, **kwargs): - if self.is_train: - return self.train_step(*args, **kwargs) - else: - return self.eval_step(*args, **kwargs) - - def eval_step(self, *args, **kwargs): - """ - inputs: model inputs and/or labels. - outputs: if 'fetches' is not provided, outputs are loss and - (if in mixed mode and is finishing gradient accumulation) all_finite. - if fetches is provided, outputs contains these requested with fetches. - fetches: names of requested outputs - """ - - # with model_loss_cls, the last input is label, first output is loss - input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) - - if self.onnx_model_ is None: - if self.torch_model_ is not None: - self._init_onnx_model(input) - else: - raise RuntimeError( - "Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer." - ) - - input_desc = self.model_desc_.inputs_[0 : len(input)] - if fetches is None: - output_desc = self.model_desc_.outputs_ - else: - output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch] - - if not isinstance(input, (list, tuple)): - input = (input,) - - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - run_options.training_mode = False - - session_run_results = ort_training_session_run_helper( - self.session, self.eval_io_binding, input, input_desc, output_desc, self.device_, run_options - ) - - if len(session_run_results) == 1: - return session_run_results[next(iter(session_run_results.keys()))] - else: - return [session_run_results[output_desc.name_] for output_desc in output_desc] - - def _verify_fully_optimized_model(self, model): - assert len(model.graph.output) > 0 - # model's first output must be the loss tensor - if model.graph.output[0].type.tensor_type.elem_type not in { - onnx.TensorProto.FLOAT, - onnx.TensorProto.FLOAT16, - onnx.TensorProto.DOUBLE, - onnx.TensorProto.COMPLEX64, - onnx.TensorProto.COMPLEX128, - onnx.TensorProto.BFLOAT16, - onnx.TensorProto.FLOAT8E4M3FN, - onnx.TensorProto.FLOAT8E4M3FNUZ, - onnx.TensorProto.FLOAT8E5M2, - onnx.TensorProto.FLOAT8E5M2FNUZ, - }: - raise RuntimeError( - "the first output of a model to run with fully optimized ORT backend must be float types." - ) - if len(model.graph.output[0].type.tensor_type.shape.dim) != 0: - raise RuntimeError( - "the first output of a model to run with fully optimized ORT backend assumed to be loss and must be a scalar." 
- ) - - -class LossScaler: - def __init__( - self, - loss_scale_input_name, - is_dynamic_scale, - loss_scale=float(1 << 16), - up_scale_window=2000, - min_loss_scale=1.0, - max_loss_scale=float(1 << 24), - ): - super().__init__() - self.loss_scale_input_name_ = loss_scale_input_name - self.is_dynamic_scale_ = is_dynamic_scale - self.initial_loss_scale_ = loss_scale - self.up_scale_window_ = up_scale_window - self.min_loss_scale_ = min_loss_scale - self.max_loss_scale_ = max_loss_scale - self.loss_scale_ = loss_scale - self.stable_steps_ = 0 - - def update_loss_scale(self, is_all_finite): - if not self.is_dynamic_scale_: - return - - if is_all_finite: - self.stable_steps_ += 1 - - if self.stable_steps_ >= self.up_scale_window_: - self.loss_scale_ = min(self.max_loss_scale_, self.loss_scale_ * 2) - self.stable_steps_ = 0 - else: - self.loss_scale_ = max(self.min_loss_scale_, self.loss_scale_ / 2) - self.stable_steps_ = 0 - - def reset(self): - self.loss_scale_ = self.initial_loss_scale_ - self.stable_steps_ = 0 diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index a08e8bee99cee..bb1cb4bbd32f7 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -18,7 +18,6 @@ #include "core/session/environment.h" #include "core/session/custom_ops.h" #include "core/dlpack/dlpack_converter.h" -#include "orttraining/core/session/training_session.h" #include "orttraining/core/agent/training_agent.h" #include "orttraining/core/graph/gradient_config.h" #include "orttraining/core/graph/optimizer_config.h" @@ -113,14 +112,11 @@ struct TrainingParameters { std::unordered_set weights_to_train; std::unordered_set weights_not_to_train; - onnxruntime::training::TrainingSession::ImmutableWeights immutable_weights; - // optimizer std::string training_optimizer_name; std::string lr_params_feed_name = "Learning_Rate"; std::unordered_map> optimizer_attributes_map; std::unordered_map> optimizer_int_attributes_map; - onnxruntime::training::TrainingSession::OptimizerState optimizer_initial_state; std::unordered_map> sliced_schema; std::unordered_map sliced_axes; std::vector sliced_tensor_names; @@ -206,185 +202,6 @@ struct PyGradientGraphBuilderContext { local_registries_(local_registries) {} }; -// TODO: this method does not handle parallel optimization. -TrainingConfigurationResult ConfigureSessionForTraining( - training::PipelineTrainingSession* sess, TrainingParameters& parameters) { - // TODO tix, refactor the mpi related code to populate all fields correctly by default. - ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size, "data_parallel_size: ", parameters.data_parallel_size, ", world_size: ", parameters.world_size); - ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size, "horizontal_parallel_size: ", parameters.horizontal_parallel_size, ", world_size: ", parameters.world_size); - ORT_ENFORCE(parameters.pipeline_parallel_size <= parameters.world_size, "pipeline_parallel_size: ", parameters.pipeline_parallel_size, ", world_size: ", parameters.world_size); - - // When DxHxP != the total number of ranks, we try adjusting D so that DxHxP == the total number of ranks. 
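A hedged sketch of the D adjustment described in the comment above (illustrative values, not the ORT implementation): when D*H*P does not equal the total number of ranks, D is recomputed as world_size / (H*P), which requires world_size to be divisible by H*P.

    # Hedged sketch of the automatic data-parallel-size correction (illustrative only).
    world_size, D, H, P = 16, 2, 2, 2                # D*H*P == 8 != 16
    if world_size != D * H * P:
        assert world_size % (H * P) == 0, "total ranks must be divisible by HxP"
        D = world_size // (H * P)
    print(D)                                         # 4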
- if (parameters.world_size != parameters.data_parallel_size * parameters.horizontal_parallel_size * parameters.pipeline_parallel_size) { - ORT_ENFORCE(parameters.world_size % parameters.horizontal_parallel_size * parameters.pipeline_parallel_size == 0, - "D, H, P sizes are incorrect. To enable automatic correction, total number of ranks must be a divisible by HxP."); - - const auto new_data_parallel_size = parameters.world_size / (parameters.horizontal_parallel_size * parameters.pipeline_parallel_size); - parameters.data_parallel_size = new_data_parallel_size; - - const std::string msg = "Cannot distribute " + std::to_string(parameters.world_size) + " ranks for distributed computation with D=" + std::to_string(parameters.data_parallel_size) + - ", H=" + std::to_string(parameters.horizontal_parallel_size) + ", P=" + std::to_string(parameters.pipeline_parallel_size) + ", so D is automatically changed to " + std::to_string(new_data_parallel_size); - LOGS(*(sess->GetLogger()), WARNING) << msg; - } - - training::PipelineTrainingSession::TrainingConfiguration config{}; - config.weight_names_to_train = parameters.weights_to_train; - config.weight_names_to_not_train = parameters.weights_not_to_train; - config.immutable_weights = parameters.immutable_weights; - config.gradient_accumulation_steps = parameters.gradient_accumulation_steps; - - config.distributed_config.world_rank = parameters.world_rank; - config.distributed_config.world_size = parameters.world_size; - config.distributed_config.local_rank = parameters.local_rank; - config.distributed_config.local_size = parameters.local_size; - config.distributed_config.data_parallel_size = parameters.data_parallel_size; - config.distributed_config.horizontal_parallel_size = parameters.horizontal_parallel_size; - config.distributed_config.pipeline_parallel_size = parameters.pipeline_parallel_size; - config.distributed_config.num_pipeline_micro_batches = parameters.num_pipeline_micro_batches; - config.distributed_config.sliced_schema = parameters.sliced_schema; - config.distributed_config.sliced_axes = parameters.sliced_axes; - config.distributed_config.sliced_tensor_names = parameters.sliced_tensor_names; - - if (parameters.use_mixed_precision) { - training::PipelineTrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mp{}; - mp.use_mixed_precision_initializers = true; - - config.mixed_precision_config = mp; - } - - if (config.distributed_config.pipeline_parallel_size > 1) { - training::PipelineTrainingSession::TrainingConfiguration::PipelineConfiguration pipeline_config; - - // Currently don't support auto-partition. User needs to pass in cut information for pipeline - pipeline_config.do_partition = true; - assert(!parameters.pipeline_cut_info_string.empty()); - - auto process_with_delimiter = [](std::string& input_str, const std::string& delimiter) { - std::vector result; - size_t pos = 0; - while ((pos = input_str.find(delimiter)) != std::string::npos) { - std::string token = input_str.substr(0, pos); - result.emplace_back(token); - input_str.erase(0, pos + delimiter.length()); - } - // push the last split of substring into result. 
- result.emplace_back(input_str); - return result; - }; - - auto process_cut_info = [&](std::string& cut_info_string) { - std::vector cut_list; - const std::string group_delimiter = ","; - const std::string edge_delimiter = ":"; - const std::string consumer_delimiter = "/"; - const std::string producer_consumer_delimiter = "-"; - - auto cut_info_groups = process_with_delimiter(cut_info_string, group_delimiter); - for (auto& cut_info_group : cut_info_groups) { - PipelineTrainingSession::TrainingConfiguration::CutInfo cut_info; - auto cut_edges = process_with_delimiter(cut_info_group, edge_delimiter); - for (auto& cut_edge : cut_edges) { - auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter); - if (process_edge.size() == 1) { - PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]}; - cut_info.emplace_back(edge); - } else { - ORT_ENFORCE(process_edge.size() == 2); - auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter); - - PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0], consumer_list}; - cut_info.emplace_back(edge); - } - } - cut_list.emplace_back(cut_info); - } - return cut_list; - }; - - pipeline_config.cut_list = process_cut_info(parameters.pipeline_cut_info_string); - config.pipeline_config = pipeline_config; - } - config.loss_name = parameters.loss_output_name; - - if (!parameters.training_optimizer_name.empty()) { - training::PipelineTrainingSession::TrainingConfiguration::OptimizerConfiguration opt{}; - opt.name = parameters.training_optimizer_name; - opt.learning_rate_input_name = parameters.lr_params_feed_name; - opt.weight_attributes_generator = [¶meters](const std::string& weight_name) { - const auto it = parameters.optimizer_attributes_map.find(weight_name); - ORT_ENFORCE( - it != parameters.optimizer_attributes_map.end(), - "Failed to find attribute map for weight ", weight_name); - return it->second; - }; - opt.weight_int_attributes_generator = [¶meters](const std::string& weight_name) { - const auto it = parameters.optimizer_int_attributes_map.find(weight_name); - ORT_ENFORCE( - it != parameters.optimizer_int_attributes_map.end(), - "Failed to find int attribute map for weight ", weight_name); - return it->second; - }; - opt.use_mixed_precision_moments = parameters.use_fp16_moments; - opt.do_all_reduce_in_mixed_precision_type = true; - // TODO: this mapping is temporary. - // For now, nccl allreduce kernel only implements for allreduce_post_accumulation - // hovorod allreduce kernel only implements for not allreduce_post_accumulation. - // eventually we will have one all reduce kernel and let opt to have - // an allreduce_post_accumulation option and remove the use_nccl option. 
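The pipeline_cut_info_string parsed a few lines above packs four delimiters; a hedged Python rendering of the same scheme (an illustrative helper, not the ORT API) may make the format easier to read: groups are separated by ',', edges within a group by ':', a producer from its consumer list by '-', and consumers by '/'.

    # Hedged sketch of the cut-info string format (not the ORT implementation).
    def parse_cut_info(cut_info_string):
        cut_list = []
        for group in cut_info_string.split(","):       # one CutInfo per group
            cut_info = []
            for edge in group.split(":"):              # one CutEdge per edge
                parts = edge.split("-")                # producer[-consumer/consumer...]
                if len(parts) == 1:
                    cut_info.append({"tensor": parts[0]})
                else:
                    cut_info.append({"tensor": parts[0], "consumers": parts[1].split("/")})
            cut_list.append(cut_info)
        return cut_list

    # parse_cut_info("t1:t2-n1/n2,t3") ->
    # [[{'tensor': 't1'}, {'tensor': 't2', 'consumers': ['n1', 'n2']}], [{'tensor': 't3'}]]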
- opt.use_nccl = parameters.allreduce_post_accumulation; - opt.deepspeed_zero = onnxruntime::training::ZeROConfig(parameters.deepspeed_zero_stage); - opt.enable_grad_norm_clip = parameters.enable_grad_norm_clip; - - // TODO reduction types - if (parameters.enable_adasum) { -#ifdef USE_CUDA - opt.adasum_reduction_type = training::AdasumReductionType::GpuHierarchicalReduction; -#else - opt.adasum_reduction_type = training::AdasumReductionType::CpuReduction; -#endif - } - - config.optimizer_config = opt; - } - - if (!parameters.optimizer_initial_state.empty()) { - config.init_optimizer_states = parameters.optimizer_initial_state; - } - - config.gradient_graph_config.use_memory_efficient_gradient = parameters.use_memory_efficient_gradient; - config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs; - - config.graph_transformer_config.attn_dropout_recompute = parameters.attn_dropout_recompute; - config.graph_transformer_config.gelu_recompute = parameters.gelu_recompute; - config.graph_transformer_config.transformer_layer_recompute = parameters.transformer_layer_recompute; - config.graph_transformer_config.number_recompute_layers = parameters.number_recompute_layers; - config.graph_transformer_config.propagate_cast_ops_config.strategy = parameters.propagate_cast_ops_strategy; - config.graph_transformer_config.propagate_cast_ops_config.level = parameters.propagate_cast_ops_level; - config.graph_transformer_config.propagate_cast_ops_config.allow = parameters.propagate_cast_ops_allow; - - if (!parameters.model_after_graph_transforms_path.empty()) { - config.model_after_graph_transforms_path = ToPathString(parameters.model_after_graph_transforms_path); - } - if (!parameters.model_with_gradient_graph_path.empty()) { - config.model_with_gradient_graph_path = ToPathString(parameters.model_with_gradient_graph_path); - } - if (!parameters.model_with_training_graph_path.empty()) { - config.model_with_training_graph_path = ToPathString(parameters.model_with_training_graph_path); - } - - training::PipelineTrainingSession::TrainingConfigurationResult config_result{}; - - OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result)); - - TrainingConfigurationResult python_config_result{}; - if (config_result.mixed_precision_config_result.has_value()) { - const auto& mp_config_result = config_result.mixed_precision_config_result.value(); - python_config_result.loss_scale_input_name = mp_config_result.loss_scale_input_name; - } - - return python_config_result; -} - #if defined(USE_MPI) void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) { LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank(); @@ -424,7 +241,7 @@ std::unordered_map> Con return py_tensor_state; } -void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) { +void addObjectMethodsForTraining(py::module& m) { py::class_(m, "OrtValueCache") .def(py::init<>()) .def("insert", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name, OrtValue& value) { @@ -451,7 +268,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn py::class_ parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc"); parameters.def(py::init()) .def_readwrite("loss_output_name", &TrainingParameters::loss_output_name) - .def_readwrite("immutable_weights", &TrainingParameters::immutable_weights) 
.def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train) .def_readwrite("weights_to_train", &TrainingParameters::weights_to_train) .def_readwrite("sliced_tensor_names", &TrainingParameters::sliced_tensor_names) @@ -484,25 +300,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn .def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size) .def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size) .def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size) - .def("set_optimizer_initial_state", - [](TrainingParameters& parameters, const std::unordered_map>& py_state) -> void { - onnxruntime::training::TrainingSession::OptimizerState optim_state; - for (const auto& weight_it : py_state) { - auto state = weight_it.second; - NameMLValMap state_tensors; - for (auto& initializer : state) { - OrtValue ml_value; - - // InputDeflist is null because parameters havent been tied to session yet - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - CreateGenericMLValue(nullptr, GetAllocator(), "", initializer.second, &ml_value, true); - ThrowIfPyErrOccured(); - state_tensors.emplace(initializer.first, ml_value); - } - optim_state.emplace(weight_it.first, state_tensors); - } - parameters.optimizer_initial_state = optim_state; - }) .def_readwrite("model_after_graph_transforms_path", &TrainingParameters::model_after_graph_transforms_path) .def_readwrite("model_with_gradient_graph_path", &TrainingParameters::model_with_gradient_graph_path) .def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path) @@ -611,130 +408,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn }); #endif - py::class_ config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc"); - config_result.def(py::init()) - .def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object { - if (result.loss_scale_input_name.has_value()) { - return py::str{result.loss_scale_input_name.value()}; - } - return py::none(); - }); - - // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user - struct PyTrainingSession : public PyInferenceSession { - PyTrainingSession(std::shared_ptr env, const PySessionOptions& so) - : PyInferenceSession(env, std::make_unique(so.value, *env)) { - } - ~PyTrainingSession() = default; - }; - - py::class_ training_session(m, "TrainingSession"); - training_session - .def(py::init([](const PySessionOptions& so) { - auto& training_env = GetTrainingEnv(); - return std::make_unique(training_env.GetORTEnv(), so); - })) - .def(py::init([]() { - auto& training_env = GetTrainingEnv(); - return std::make_unique(training_env.GetORTEnv(), GetDefaultCPUSessionOptions()); - })) - .def("finalize", [](py::object) { -#if defined(USE_MPI) -#ifdef _WIN32 - // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices - // shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction - // call shutdown_mpi() here instead. 
- MPIContext::shutdown_mpi(); -#endif -#endif - }) - .def("load_model", [ep_registration_fn](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector& provider_types, const ProviderOptionsVector& provider_options) { - OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path)); - -#if defined(USE_MPI) - bool use_nccl = parameters.allreduce_post_accumulation; - if (!use_nccl && parameters.world_size > 1) - CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger()); -#endif - const auto config_result = ConfigureSessionForTraining(static_cast(sess->GetSessionHandle()), parameters); - - ProviderOptionsVector merged_options; - ResolveExtraProviderOptions(provider_types, provider_options, merged_options); - - InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options); - - return config_result; - }) - .def("read_bytes", [ep_registration_fn](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector& provider_types, const ProviderOptionsVector& provider_options) { - std::istringstream buffer(serialized_model); - OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer)); - -#if defined(USE_MPI) - bool use_nccl = parameters.allreduce_post_accumulation; - if (!use_nccl && parameters.world_size > 1) - CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger()); -#endif - const auto config_result = ConfigureSessionForTraining(static_cast(sess->GetSessionHandle()), parameters); - ProviderOptionsVector merged_options; - ResolveExtraProviderOptions(provider_types, provider_options, merged_options); - - InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options); - - return config_result; - }) - .def("get_state", [](PyTrainingSession* sess) { - NameMLValMap state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetStateTensors(state_tensors)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - // convert to numpy array - std::map rmap; - for (auto& kv : state_tensors) { - if (kv.second.IsTensor()) { - py::object obj; - const Tensor& rtensor = kv.second.Get(); - GetPyObjFromTensor(rtensor, obj, &data_transfer_manager); - rmap.insert({kv.first, obj}); - } else { - throw std::runtime_error("Non tensor type in session state tensors is not expected."); - } - } - return rmap; - }) - .def("get_model_state", [](PyTrainingSession* sess, bool include_mixed_precision_weights) { - std::unordered_map model_state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetModelState(model_state_tensors, include_mixed_precision_weights)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - return ConvertORTTensorMapToNumpy(model_state_tensors, data_transfer_manager); - }) - .def("get_optimizer_state", [](PyTrainingSession* sess) { - std::unordered_map opt_state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetOptimizerState(opt_state_tensors)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - return ConvertORTTensorMapToNumpy(opt_state_tensors, data_transfer_manager); - }) - .def("get_partition_info_map", [](PyTrainingSession* sess) { - std::unordered_map>> part_info_map; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetPartitionInfoMap(part_info_map)); - return part_info_map; - }) - .def("load_state", [](PyTrainingSession* sess, 
std::unordered_map& state, bool strict) { - NameMLValMap state_tensors; - for (auto initializer : state) { - OrtValue ml_value; - auto px = sess->GetSessionHandle()->GetModelInputs(); - if (!px.first.IsOK() || !px.second) { - throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null"); - } - CreateGenericMLValue(px.second, GetAllocator(), initializer.first, initializer.second, &ml_value); - ThrowIfPyErrOccured(); - state_tensors.insert(std::make_pair(initializer.first, ml_value)); - } - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->SetStateTensors(state_tensors, strict)); - }) - .def("is_output_fp32_node", [](PyTrainingSession* sess, const std::string& output_name) { - return static_cast(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name); - }); - py::class_(m, "PartialGraphExecutionState") .def(py::init([]() { return std::make_unique(); diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 88ef90a7feaa8..4d1db7334f280 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -40,7 +40,7 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* void addGlobalMethods(py::module& m); void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn); -void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn); +void addObjectMethodsForTraining(py::module& m); void addObjectMethodsForEager(py::module& m); #ifdef ENABLE_LAZY_TENSOR void addObjectMethodsForLazyTensor(py::module& m); @@ -339,7 +339,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { } #endif - addObjectMethodsForTraining(m, ORTTrainingRegisterExecutionProviders); + addObjectMethodsForTraining(m); #ifdef ENABLE_LAZY_TENSOR addObjectMethodsForLazyTensor(m); diff --git a/orttraining/orttraining/python/training/__init__.py b/orttraining/orttraining/python/training/__init__.py index 73b1f826f68e1..a3c22686a1039 100644 --- a/orttraining/orttraining/python/training/__init__.py +++ b/orttraining/orttraining/python/training/__init__.py @@ -8,26 +8,16 @@ TrainingParameters, is_ortmodule_available, ) -from onnxruntime.capi.training.training_session import TrainingSession - # Options need to be imported before `ORTTrainer`. -from .orttrainer_options import ORTTrainerOptions -from .orttrainer import ORTTrainer, TrainStepInfo -from . import amp, artifacts, checkpoint, model_desc_validation, optim +from . import amp, artifacts, optim __all__ = [ "PropagateCastOpsStrategy", "TrainingParameters", "is_ortmodule_available", - "TrainingSession", - "ORTTrainerOptions", - "ORTTrainer", - "TrainStepInfo", "amp", "artifacts", - "checkpoint", - "model_desc_validation", "optim", ] diff --git a/orttraining/orttraining/python/training/_checkpoint_storage.py b/orttraining/orttraining/python/training/_checkpoint_storage.py deleted file mode 100644 index 7a8ada7dee96b..0000000000000 --- a/orttraining/orttraining/python/training/_checkpoint_storage.py +++ /dev/null @@ -1,107 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -import pickle -from collections.abc import Mapping - -import h5py - - -def _dfs_save(group, save_obj): - """Recursively go over each level in the save_obj dictionary and save values to a hdf5 group""" - - for key, value in save_obj.items(): - if isinstance(value, Mapping): - subgroup = group.create_group(key) - _dfs_save(subgroup, value) - else: - group[key] = value - - -def save(save_obj: dict, path): - """Persists the input dictionary to a file specified by path. - - Saves an hdf5 representation of the save_obj dictionary to a file or a file-like object specified by path. - Values are saved in a format supported by h5py. For example, a PyTorch tensor is saved and loaded as a - numpy object. So, user types may be converted from their original types to numpy equivalent types. - - Args: - save_obj: dictionary that needs to be saved. - save_obj should consist of types supported by hdf5 file format. - if hdf5 does not recognize a type, an exception is raised. - if save_obj is not a dictionary, a ValueError is raised. - path: string representation to a file path or a python file-like object. - if file already exists at path, an exception is raised. - """ - if not isinstance(save_obj, Mapping): - raise ValueError("Object to be saved must be a dictionary") - - with h5py.File(path, "w-") as f: - _dfs_save(f, save_obj) - - -def _dfs_load(group, load_obj): - """Recursively go over each level in the hdf5 group and load the values into the given dictionary""" - - for key in group: - if isinstance(group[key], h5py.Group): - load_obj[key] = {} - _dfs_load(group[key], load_obj[key]) - else: - load_obj[key] = group[key][()] - - -def load(path, key=None): - """Loads the data stored in the binary file specified at the given path into a dictionary and returns it. - - Loads the data from an hdf5 file specified at the given path into a python dictionary. - Loaded dictionary contains numpy equivalents of python data types. For example: - PyTorch tensor -> saved as a numpy array and loaded as a numpy array. - bool -> saved as a numpy bool and loaded as a numpy bool - If a '/' separated key is provided, the value at that hierarchical level in the hdf5 group is returned. - - Args: - path: string representation to a file path or a python file-like object. - if file does not already exist at path, an exception is raised. - key: '/' separated representation of the hierarchy level value that needs to be returned/ - for example, if the saved binary file has structure {a: {b: x, c:y}} and the user would like - to query the value for c, the key provided should be 'a/c'. - the default value of None for key implies that the entire hdf5 file structure needs to be loaded into a dictionary and returned. - - Returns: - a dictionary loaded from the specified binary hdf5 file. - """ - if not h5py.is_hdf5(path): - raise ValueError(f"{path} is not an hdf5 file or a python file-like object.") - - load_obj = {} - with h5py.File(path, "r") as f: - if key: - f = f[key] # noqa: PLW2901 - if isinstance(f, h5py.Dataset): - return f[()] - - _dfs_load(f, load_obj) - - return load_obj - - -def to_serialized_hex(user_dict): - """Serialize the user_dict and convert the serialized bytes to a hex string and return""" - - return pickle.dumps(user_dict).hex() - - -def from_serialized_hex(serialized_hex): - """Convert serialized_hex to bytes and deserialize it and return""" - - # serialized_hex can be either a regular string or a byte string. 
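The docstrings above fully describe the hdf5 helpers being deleted; for reference, a hedged usage sketch based only on those docstrings (requires h5py, and save() raises if the file already exists):

    # Hedged sketch of the removed _checkpoint_storage helpers (module deleted by this change).
    import numpy as np
    from onnxruntime.training import _checkpoint_storage

    state = {"model": {"weight": np.arange(4, dtype=np.float32)}, "step": 7}
    _checkpoint_storage.save(state, "ckpt.h5")                        # writes an hdf5 file
    everything = _checkpoint_storage.load("ckpt.h5")                  # whole nested dict
    weight = _checkpoint_storage.load("ckpt.h5", key="model/weight")  # '/'-separated lookup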
- # if it is a byte string, convert to regular string using decode() - # if it is a regular string, do nothing to it - try: # noqa: SIM105 - serialized_hex = serialized_hex.decode() - except AttributeError: - pass - return pickle.loads(bytes.fromhex(serialized_hex)) diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py index 4eb79443c8f1a..091274d1d171d 100644 --- a/orttraining/orttraining/python/training/_utils.py +++ b/orttraining/orttraining/python/training/_utils.py @@ -6,11 +6,9 @@ import importlib.util import os import sys -from functools import wraps # noqa: F401 import numpy as np import torch -from onnx import TensorProto # noqa: F401 from packaging.version import Version @@ -23,16 +21,6 @@ def get_device_index(device): return 0 if device.index is None else device.index -def get_device_index_from_input(input): - """Returns device index from a input PyTorch Tensor""" - - if isinstance(input, (list, tuple)): - device_index = get_device_index(input[0].device) - else: - device_index = get_device_index(input.device) - return device_index - - def get_device_str(device): if isinstance(device, str): # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 @@ -50,24 +38,6 @@ def get_device_str(device): return device -def get_all_gradients_finite_name_from_session(session): - """Find all_gradients_finite node on Session graph and return its name""" - - nodes = [x for x in session._outputs_meta if "all_gradients_finite" in x.name] - if len(nodes) != 1: - raise RuntimeError("'all_gradients_finite' node not found within training session") - return nodes[0].name - - -def get_gradient_accumulation_name_from_session(session): - """Find Group_Accumulated_Gradients node on Session graph and return its name""" - - nodes = [x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name] - if len(nodes) != 1: - raise RuntimeError("'Group_Accumulated_Gradients' node not found within training session") - return nodes[0].name - - def dtype_torch_to_numpy(torch_dtype): """Converts PyTorch types to Numpy types @@ -232,111 +202,3 @@ def import_module_from_file(file_path, module_name=None): sys.modules[module_name] = module spec.loader.exec_module(module) return module - - -def state_dict_model_key(): - """Returns the model key name in the state dictionary""" - - return "model" - - -def state_dict_optimizer_key(): - """Returns the optimizer key name in the state dictionary""" - - return "optimizer" - - -def state_dict_partition_info_key(): - """Returns the partition info key name in the state dictionary""" - - return "partition_info" - - -def state_dict_trainer_options_key(): - """Returns the trainer options key name in the state dictionary""" - - return "trainer_options" - - -def state_dict_full_precision_key(): - """Returns the full precision key name in the state dictionary""" - - return "full_precision" - - -def state_dict_original_dimension_key(): - """Returns the original dimension key name in the state dictionary""" - - return "original_dim" - - -def state_dict_sharded_optimizer_keys(): - """Returns the optimizer key names that can be sharded in the state dictionary""" - - return {"Moment_1", "Moment_2"} - - -def state_dict_user_dict_key(): - """Returns the user dict key name in the state dictionary""" - - return "user_dict" - - -def state_dict_trainer_options_mixed_precision_key(): - """Returns the trainer options mixed precision key name in the state dictionary""" - - return "mixed_precision" - - -def 
state_dict_trainer_options_zero_stage_key(): - """Returns the trainer options zero_stage key name in the state dictionary""" - - return "zero_stage" - - -def state_dict_trainer_options_world_rank_key(): - """Returns the trainer options world_rank key name in the state dictionary""" - - return "world_rank" - - -def state_dict_trainer_options_world_size_key(): - """Returns the trainer options world_size key name in the state dictionary""" - - return "world_size" - - -def state_dict_trainer_options_data_parallel_size_key(): - """Returns the trainer options data_parallel_size key name in the state dictionary""" - - return "data_parallel_size" - - -def state_dict_trainer_options_horizontal_parallel_size_key(): - """Returns the trainer options horizontal_parallel_size key name in the state dictionary""" - - return "horizontal_parallel_size" - - -def state_dict_trainer_options_optimizer_name_key(): - """Returns the trainer options optimizer_name key name in the state dictionary""" - - return "optimizer_name" - - -def state_dict_train_step_info_key(): - """Returns the train step info key name in the state dictionary""" - - return "train_step_info" - - -def state_dict_train_step_info_optimization_step_key(): - """Returns the train step info optimization step key name in the state dictionary""" - - return "optimization_step" - - -def state_dict_train_step_info_step_key(): - """Returns the train step info step key name in the state dictionary""" - - return "step" diff --git a/orttraining/orttraining/python/training/checkpoint.py b/orttraining/orttraining/python/training/checkpoint.py deleted file mode 100644 index d0ff0650662b7..0000000000000 --- a/orttraining/orttraining/python/training/checkpoint.py +++ /dev/null @@ -1,748 +0,0 @@ -import os -import tempfile -import warnings -from enum import Enum - -import numpy as np -import onnx -import torch - -from . import _checkpoint_storage, _utils - -################################################################################ -# Experimental Checkpoint APIs -################################################################################ - - -def experimental_state_dict(ort_trainer, include_optimizer_state=True): - warnings.warn( - "experimental_state_dict() will be deprecated soon. Please use ORTTrainer.state_dict() instead.", - DeprecationWarning, - ) - - if not ort_trainer._training_session: - warnings.warn( - "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling state_dict()." - ) - return ort_trainer._state_dict - - # extract trained weights - session_state = ort_trainer._training_session.get_state() - torch_state = {} - for name in session_state: - torch_state[name] = torch.from_numpy(session_state[name]) - - # extract untrained weights and buffer - for n in ort_trainer._onnx_model.graph.initializer: - if n.name not in torch_state and n.name in ort_trainer.options.utils.frozen_weights: - torch_state[n.name] = torch.from_numpy(np.array(onnx.numpy_helper.to_array(n))) - - # Need to remove redundant (optimizer) initializers to map back to original torch state names - if not include_optimizer_state and ort_trainer._torch_state_dict_keys: - return {key: torch_state[key] for key in ort_trainer._torch_state_dict_keys if key in torch_state} - return torch_state - - -def experimental_load_state_dict(ort_trainer, state_dict, strict=False): - warnings.warn( - "experimental_load_state_dict() will be deprecated soon. 
Please use ORTTrainer.load_state_dict() instead.", - DeprecationWarning, - ) - - # Note: It may happen ONNX model has not yet been initialized - # In this case we cache a reference to desired state and delay the restore until after initialization - # Unexpected behavior will result if the user changes the reference before initialization - if not ort_trainer._training_session: - ort_trainer._state_dict = state_dict - ort_trainer._load_state_dict_strict = strict - return - - # Update onnx model from loaded state dict - cur_initializers_names = [n.name for n in ort_trainer._onnx_model.graph.initializer] - new_initializers = {} - - for name in state_dict: - if name in cur_initializers_names: - new_initializers[name] = state_dict[name].numpy() - elif strict: - raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.") - - ort_trainer._update_onnx_model_initializers(new_initializers) - - # create new session based on updated onnx model - ort_trainer._state_dict = None - ort_trainer._init_session() - - # load training state - session_state = {name: state_dict[name].numpy() for name in state_dict} - ort_trainer._training_session.load_state(session_state, strict) - - -def experimental_save_checkpoint( - ort_trainer, - checkpoint_dir, - checkpoint_prefix="ORT_checkpoint", - checkpoint_state_dict=None, - include_optimizer_state=True, -): - warnings.warn( - "experimental_save_checkpoint() will be deprecated soon. Please use ORTTrainer.save_checkpoint() instead.", - DeprecationWarning, - ) - - if checkpoint_state_dict is None: - checkpoint_state_dict = {"model": experimental_state_dict(ort_trainer, include_optimizer_state)} - else: - checkpoint_state_dict.update({"model": experimental_state_dict(ort_trainer, include_optimizer_state)}) - - assert os.path.exists(checkpoint_dir), f"checkpoint_dir ({checkpoint_dir}) directory doesn't exist" - - checkpoint_name = _get_checkpoint_name( - checkpoint_prefix, - ort_trainer.options.distributed.deepspeed_zero_optimization.stage, - ort_trainer.options.distributed.world_rank, - ort_trainer.options.distributed.world_size, - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - if os.path.exists(checkpoint_file): - msg = f"{checkpoint_file} already exists, overwriting." - warnings.warn(msg) - torch.save(checkpoint_state_dict, checkpoint_file) - - -def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): - warnings.warn( - "experimental_load_checkpoint() will be deprecated soon. Please use ORTTrainer.load_checkpoint() instead.", - DeprecationWarning, - ) - - checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - is_partitioned = False - if len(checkpoint_files) > 1: - msg = ( - f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." - " Attempting to load ZeRO checkpoint." 
- ) - warnings.warn(msg) - is_partitioned = True - if (not ort_trainer.options.distributed.deepspeed_zero_optimization.stage) and is_partitioned: - return _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict) - else: - return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) - - -class _AGGREGATION_MODE(Enum): # noqa: N801 - Zero = 0 - Megatron = 1 - - -def _order_paths(paths, D_groups, H_groups): - """Reorders the given paths in order of aggregation of ranks for D and H parallellism respectively - and returns the ordered dict""" - - trainer_options_path_tuples = [] - world_rank = _utils.state_dict_trainer_options_world_rank_key() - - for path in paths: - trainer_options_path_tuples.append( - (_checkpoint_storage.load(path, key=_utils.state_dict_trainer_options_key()), path) - ) - - # sort paths according to rank - sorted_paths = [ - path - for _, path in sorted( - trainer_options_path_tuples, key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank] - ) - ] - - ordered_paths = dict() - ordered_paths["D"] = [[sorted_paths[i] for i in D_groups[group_id]] for group_id in range(len(D_groups))] - ordered_paths["H"] = [[sorted_paths[i] for i in H_groups[group_id]] for group_id in range(len(H_groups))] - - return ordered_paths - - -def _add_or_update_sharded_key( - state_key, state_value, state_sub_dict, model_state_key, state_partition_info, sharded_states_original_dims, mode -): - """Add or update the record for the sharded state_key in the state_sub_dict""" - - # record the original dimension for this state - original_dim = _utils.state_dict_original_dimension_key() - sharded_states_original_dims[model_state_key] = state_partition_info[original_dim] - - axis = 0 - if mode == _AGGREGATION_MODE.Megatron and state_partition_info["megatron_row_partition"] == 0: - axis = -1 - - if state_key in state_sub_dict: - # state_dict already contains a record for this state - # since this state is sharded, concatenate the state value to - # the record in the state_dict - state_sub_dict[state_key] = np.concatenate((state_sub_dict[state_key], state_value), axis) - else: - # create a new entry for this state in the state_dict - state_sub_dict[state_key] = state_value - - -def _add_or_validate_unsharded_key(state_key, state_value, state_sub_dict, mismatch_error_string): - """Add or validate the record for the unsharded state_key in the state_sub_dict""" - - if state_key in state_sub_dict: - # state_dict already contains a record for this unsharded state. 
- # assert that all values are the same for this previously loaded state - assert (state_sub_dict[state_key] == state_value).all(), mismatch_error_string - else: - # create a new entry for this state in the state_sub_dict - state_sub_dict[state_key] = state_value - - -def _aggregate_model_states( - rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled, mode=_AGGREGATION_MODE.Zero -): - """Aggregates all model states from the rank_state_dict into state_dict""" - - model = _utils.state_dict_model_key() - full_precision = _utils.state_dict_full_precision_key() - partition_info = _utils.state_dict_partition_info_key() - - # if there are no model states in the rank_state_dict, no model aggregation is needed - if model not in rank_state_dict: - return - - if model not in state_dict: - state_dict[model] = {} - - if full_precision not in state_dict[model]: - state_dict[model][full_precision] = {} - - # iterate over all model state keys - for model_state_key, model_state_value in rank_state_dict[model][full_precision].items(): - # ZERO: full precision model states are sharded only when they exist in the partition_info subdict and mixed - # precision training was enabled. for full precision training, full precision model states are not sharded - # MEGATRON : full precision model states are sharded when they exist in the partition_info subdict - if (model_state_key in rank_state_dict[partition_info]) and ( - mode == _AGGREGATION_MODE.Megatron or mixed_precision_enabled - ): - # this model state is sharded - _add_or_update_sharded_key( - model_state_key, - model_state_value, - state_dict[model][full_precision], - model_state_key, - rank_state_dict[partition_info][model_state_key], - sharded_states_original_dims, - mode, - ) - else: - # this model state is not sharded since a record for it does not exist in the partition_info subdict - _add_or_validate_unsharded_key( - model_state_key, - model_state_value, - state_dict[model][full_precision], - f"Value mismatch for model state {model_state_key}", - ) - - -def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode=_AGGREGATION_MODE.Zero): - """Aggregates all optimizer states from the rank_state_dict into state_dict""" - - optimizer = _utils.state_dict_optimizer_key() - partition_info = _utils.state_dict_partition_info_key() - sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() - - # if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed - if optimizer not in rank_state_dict: - return - - if optimizer not in state_dict: - state_dict[optimizer] = {} - - # iterate over all optimizer state keys - for model_state_key, optimizer_dict in rank_state_dict[optimizer].items(): - for optimizer_key, optimizer_value in optimizer_dict.items(): - if model_state_key not in state_dict[optimizer]: - state_dict[optimizer][model_state_key] = {} - - if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]: - # this optimizer state is sharded since a record exists in the partition_info subdict - _add_or_update_sharded_key( - optimizer_key, - optimizer_value, - state_dict[optimizer][model_state_key], - model_state_key, - rank_state_dict[partition_info][model_state_key], - sharded_states_original_dims, - mode, - ) - else: - # this optimizer state is not sharded since a record for it does not exist in the partition_info subdict - # or this optimizer key is not one of the sharded optimizer keys - 
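A hedged sketch of the sharded-state aggregation path described above (values are illustrative): each rank contributes a flat shard of a partitioned state; aggregation concatenates the shards in rank order and later reshapes the result to the original dimension recorded in partition_info.

    # Hedged sketch of shard concatenation followed by the reshape step (illustrative only).
    import numpy as np
    original_dim = (4, 3)
    rank_shards = [np.arange(0, 6, dtype=np.float32), np.arange(6, 12, dtype=np.float32)]
    aggregated = np.concatenate(rank_shards, axis=0).reshape(original_dim)
    assert aggregated.shape == original_dim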
_add_or_validate_unsharded_key( - optimizer_key, - optimizer_value, - state_dict[optimizer][model_state_key], - f"Value mismatch for model state {model_state_key} and optimizer state {optimizer_key}", - ) - - -def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled): - """Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims""" - - model = _utils.state_dict_model_key() - full_precision = _utils.state_dict_full_precision_key() - optimizer = _utils.state_dict_optimizer_key() - sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() - - for sharded_state_key, original_dim in sharded_states_original_dims.items(): - # reshape model states to original_dim only when mixed precision is enabled - if mixed_precision_enabled and (model in state_dict): - state_dict[model][full_precision][sharded_state_key] = state_dict[model][full_precision][ - sharded_state_key - ].reshape(original_dim) - - # reshape optimizer states to original_dim - if optimizer in state_dict: - for optimizer_key, optimizer_value in state_dict[optimizer][sharded_state_key].items(): - if optimizer_key in sharded_optimizer_keys: - state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim) - - -def _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation): - """Extracts trainer options from rank_state_dict and loads them accordingly on state_dict""" - trainer_options = _utils.state_dict_trainer_options_key() - state_dict[trainer_options] = {} - - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - - state_dict[trainer_options][mixed_precision] = rank_state_dict[trainer_options][mixed_precision] - state_dict[trainer_options][zero_stage] = 0 - state_dict[trainer_options][world_rank] = rank_state_dict[trainer_options][world_rank] if partial_aggregation else 0 - state_dict[trainer_options][world_size] = 1 - state_dict[trainer_options][optimizer_name] = rank_state_dict[trainer_options][optimizer_name] - state_dict[trainer_options][D_size] = 1 - state_dict[trainer_options][H_size] = 1 - - -def _aggregate_megatron_partition_info(rank_state_dict, state_dict): - """Extracts partition_info from rank_state_dict and loads on state_dict for megatron-partitioned weights""" - partition_info = _utils.state_dict_partition_info_key() - if partition_info not in state_dict: - state_dict[partition_info] = {} - - rank_partition_info = rank_state_dict[partition_info] - for model_state_key, partition_info_dict in rank_partition_info.items(): - if model_state_key not in state_dict[partition_info]: - # add partition info only if weight is megatron partitioned - if partition_info_dict["megatron_row_partition"] >= 0: - state_dict[partition_info][model_state_key] = partition_info_dict - - -def _to_pytorch_format(state_dict): - """Convert ORT state dictionary schema (hierarchical structure) to PyTorch state dictionary schema (flat structure)""" - - pytorch_state_dict = {} - for model_state_key, model_state_value in 
state_dict[_utils.state_dict_model_key()][ - _utils.state_dict_full_precision_key() - ].items(): - # convert numpy array to a torch tensor - pytorch_state_dict[model_state_key] = torch.tensor(model_state_value) - return pytorch_state_dict - - -def _get_parallellism_groups(data_parallel_size, horizontal_parallel_size, world_size): - """Returns the D and H groups for the given sizes""" - num_data_groups = world_size // data_parallel_size - data_groups = [] - for data_group_id in range(num_data_groups): - data_group_ranks = [] - for r in range(data_parallel_size): - data_group_ranks.append(data_group_id + horizontal_parallel_size * r) - data_groups.append(data_group_ranks) - - num_horizontal_groups = world_size // horizontal_parallel_size - horizontal_groups = [] - for hori_group_id in range(num_horizontal_groups): - hori_group_ranks = [] - for r in range(horizontal_parallel_size): - hori_group_ranks.append(hori_group_id * horizontal_parallel_size + r) - horizontal_groups.append(hori_group_ranks) - - return data_groups, horizontal_groups - - -def _aggregate_over_ranks( - ordered_paths, - ranks, - sharded_states_original_dims=None, - mode=_AGGREGATION_MODE.Zero, - partial_aggregation=False, - pytorch_format=True, -): - """Aggregate checkpoint files over set of ranks and return a single state dictionary - - Args: - ordered_paths: list of paths in the order in which they must be aggregated - ranks: list of ranks that are to be aggregated - sharded_states_original_dims: dict containing the original dims for sharded states that are persisted over - multiple calls to _aggregate_over_ranks() - mode: mode of aggregation: Zero or Megatron - partial_aggregation: boolean flag to indicate whether to produce a partially - aggregated state which can be further aggregated over - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict - Returns: - state_dict that can be loaded into an ORTTrainer or into a PyTorch model - """ - state_dict = {} - if sharded_states_original_dims is None: - sharded_states_original_dims = dict() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - - loaded_mixed_precision = None - loaded_world_size = None - loaded_zero_stage = None - loaded_optimizer_name = None - - for i, path in enumerate(ordered_paths): - rank_state_dict = _checkpoint_storage.load(path) - - assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info" - assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options" - assert ( - ranks[i] == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] - ), "Unexpected rank in file at path {}. Expected {}, got {}".format( - path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] # noqa: F821 - ) - if loaded_mixed_precision is None: - loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] - else: - assert ( - loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] - ), f"Mixed precision state mismatch among checkpoint files. 
File: {path}" - if loaded_world_size is None: - loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] - else: - assert ( - loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] - ), f"World size state mismatch among checkpoint files. File: {path}" - if loaded_zero_stage is None: - loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] - else: - assert ( - loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] - ), f"Zero stage mismatch among checkpoint files. File: {path}" - if loaded_optimizer_name is None: - loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] - else: - assert ( - loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] - ), f"Optimizer name mismatch among checkpoint files. File: {path}" - - # aggregate all model states - _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision, mode) - - if not pytorch_format: - # aggregate all optimizer states if pytorch_format is False - _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode) - - # for D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups - # to aggregate over Zero, and another pass to aggregate Megatron partitioned - # states. Preserve the relevant partition info only for weights that are megatron partitioned for - # a partial aggregation call - if partial_aggregation: - _aggregate_megatron_partition_info(rank_state_dict, state_dict) - - # entry for trainer_options in the state_dict to perform other sanity checks - if _utils.state_dict_trainer_options_key() not in state_dict: - _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation) - - # entry for user_dict in the state_dict if not already present - if ( - _utils.state_dict_user_dict_key() not in state_dict - and _utils.state_dict_user_dict_key() in rank_state_dict - ): - state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()] - - # for a partial aggregation scenario, we might not have the entire tensor aggregated yet, thus skip reshape - if not partial_aggregation: - # reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims - _reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision) - - # return a flat structure for PyTorch model in case pytorch_format is True - # else return the hierarchical structure for ORTTrainer - return _to_pytorch_format(state_dict) if pytorch_format else state_dict - - -def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format): # noqa: N802 - """Aggregate checkpoint files and return a single state dictionary for the D+H - (Zero+Megatron) partitioning strategy. - For D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups - to aggregate over Zero, and another pass over the previously aggregated states - to aggregate Megatron partitioned states. 
- """ - sharded_states_original_dims = {} - aggregate_data_checkpoint_files = [] - - # combine for Zero over data groups and save to temp file - with tempfile.TemporaryDirectory() as save_dir: - for group_id, d_group in enumerate(D_groups): - aggregate_state_dict = _aggregate_over_ranks( - ordered_paths["D"][group_id], - d_group, - sharded_states_original_dims, - partial_aggregation=True, - pytorch_format=False, - ) - - filename = "ort.data_group." + str(group_id) + ".ort.pt" - filepath = os.path.join(save_dir, filename) - _checkpoint_storage.save(aggregate_state_dict, filepath) - aggregate_data_checkpoint_files.append(filepath) - - assert len(aggregate_data_checkpoint_files) > 0 - - # combine for megatron: - aggregate_state = _aggregate_over_ranks( - aggregate_data_checkpoint_files, - H_groups[0], - sharded_states_original_dims, - mode=_AGGREGATION_MODE.Megatron, - pytorch_format=pytorch_format, - ) - - return aggregate_state - - -def aggregate_checkpoints(paths, pytorch_format=True): - """Aggregate checkpoint files and return a single state dictionary - - Aggregates checkpoint files specified by paths and loads them one at a time, merging - them into a single state dictionary. - The checkpoint files represented by paths must be saved through ORTTrainer.save_checkpoint() function. - The schema of the state_dict returned will be in the same as the one returned by ORTTrainer.state_dict() - - Args: - paths: list of more than one file represented as strings where the checkpoint is saved - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict - Returns: - state_dict that can be loaded into an ORTTrainer or into a PyTorch model - """ - - loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key()) - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - world_size = _utils.state_dict_trainer_options_world_size_key() - - D_size = loaded_trainer_options[D_size] # noqa: N806 - H_size = loaded_trainer_options[H_size] # noqa: N806 - world_size = loaded_trainer_options[world_size] - D_groups, H_groups = _get_parallellism_groups(D_size, H_size, world_size) # noqa: N806 - - combine_zero = loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 - combine_megatron = len(H_groups[0]) > 1 - - # order the paths in the order of groups in which they must be aggregated according to - # data-parallel groups and H-parallel groups obtained - # eg: {'D': [[path_0, path_2],[path_1, path_3]], 'H': [[path_0, path_1],[path_2, path_3]]} - ordered_paths = _order_paths(paths, D_groups, H_groups) - - aggregate_state = None - if combine_zero and combine_megatron: - aggregate_state = _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format) - elif combine_zero: - aggregate_state = _aggregate_over_ranks( - ordered_paths["D"][0], D_groups[0], mode=_AGGREGATION_MODE.Zero, pytorch_format=pytorch_format - ) - elif combine_megatron: - aggregate_state = _aggregate_over_ranks( - ordered_paths["H"][0], H_groups[0], mode=_AGGREGATION_MODE.Megatron, pytorch_format=pytorch_format - ) - - return aggregate_state - - -################################################################################ -# Helper functions -################################################################################ - - -def _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, 
is_partitioned, strict): - checkpoint_name = _get_checkpoint_name( - checkpoint_prefix, - is_partitioned, - ort_trainer.options.distributed.world_rank, - ort_trainer.options.distributed.world_size, - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if is_partitioned: - assert_msg = ( - f"Couldn't find checkpoint file {checkpoint_file}." - " Optimizer partitioning is enabled using ZeRO. Please make sure the checkpoint file exists " - f"for rank {ort_trainer.options.distributed.world_rank} of {ort_trainer.options.distributed.world_size}" - ) - else: - assert_msg = f"Couldn't find checkpoint file {checkpoint_file}." - assert os.path.exists(checkpoint_file), assert_msg - - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - experimental_load_state_dict(ort_trainer, checkpoint_state["model"], strict=strict) - del checkpoint_state["model"] - return checkpoint_state - - -def _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict): - checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - - ckpt_agg = _CombineZeroCheckpoint(checkpoint_files) - aggregate_state_dict = ckpt_agg.aggregate_checkpoints() - - experimental_load_state_dict(ort_trainer, aggregate_state_dict, strict=strict) - - # aggregate other keys in the state_dict. - # Values will be overwritten for matching keys among workers - all_checkpoint_states = dict() - for checkpoint_file in checkpoint_files: - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - del checkpoint_state["model"] - all_checkpoint_states.update(checkpoint_state) - return all_checkpoint_states - - -def _list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"): - ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)] - ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)] - ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names] - - assert len(ckpt_file_names) > 0, f"No checkpoint found with prefix '{checkpoint_prefix}' at '{checkpoint_dir}'" - return ckpt_file_names - - -def _get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None): - SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806 - MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806 - - if is_partitioned: - filename = MULTIPLE_CHECKPOINT_FILENAME.format( - prefix=prefix, world_rank=world_rank, world_size=(world_size - 1) - ) - else: - filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix) - return filename - - -def _split_state_dict(state_dict): - optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] - split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} - for k, v in state_dict.items(): - mode = "fp32_param" - for optim_key in optimizer_keys: - if k.startswith(optim_key): - mode = "optimizer" - break - if k.endswith("_fp16"): - mode = "fp16_param" - split_sd[mode][k] = v - return split_sd - - -class _CombineZeroCheckpoint: - def __init__(self, checkpoint_files, clean_state_dict=None): - assert len(checkpoint_files) > 0, "No checkpoint files passed" - self.checkpoint_files = checkpoint_files - self.clean_state_dict = clean_state_dict - self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1 - assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files" - self.weight_shape_map = {} - self.sharded_params = set() - - def _split_name(self, name: str): - 
name_split = name.split("_view_") - view_num = None - if len(name_split) > 1: - view_num = int(name_split[1]) - optimizer_key = "" - mp_suffix = "" - if name_split[0].startswith("Moment_1"): - optimizer_key = "Moment_1_" - elif name_split[0].startswith("Moment_2"): - optimizer_key = "Moment_2_" - elif name_split[0].startswith("Update_Count"): - optimizer_key = "Update_Count_" - elif name_split[0].endswith("_fp16"): - mp_suffix = "_fp16" - param_name = name_split[0] - if optimizer_key: - param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split("_fp16")[0] - return param_name, optimizer_key, view_num, mp_suffix - - def _update_weight_statistics(self, name, value): - if name not in self.weight_shape_map: - self.weight_shape_map[name] = value.size() # original shape of tensor - - def _reshape_tensor(self, key): - value = self.aggregate_state_dict[key] - weight_name, _, _, _ = self._split_name(key) - set_size = self.weight_shape_map[weight_name] - self.aggregate_state_dict[key] = value.reshape(set_size) - - def _aggregate(self, param_dict): - for k, v in param_dict.items(): - weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k) - if view_num is not None: - # parameter is sharded - param_name = optimizer_key + weight_name + mp_suffix - - if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]: - self.sharded_params.add(param_name) - # Found a previous shard of the param, concatenate shards ordered by ranks - self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v)) - else: - self.aggregate_state_dict[param_name] = v - else: - if k in self.aggregate_state_dict: - assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value" - else: - self.aggregate_state_dict[k] = v - self._update_weight_statistics(weight_name, v) - - def aggregate_checkpoints(self): - warnings.warn( - "_CombineZeroCheckpoint.aggregate_checkpoints() will be deprecated soon. 
" - "Please use aggregate_checkpoints() instead.", - DeprecationWarning, - ) - - checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0] - self.aggregate_state_dict = dict() - - for i in range(self.world_size): - checkpoint_name = _get_checkpoint_name(checkpoint_prefix, True, i, self.world_size) - rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu")) - if "model" in rank_state_dict: - rank_state_dict = rank_state_dict["model"] - - if self.clean_state_dict: - rank_state_dict = self.clean_state_dict(rank_state_dict) - - rank_state_dict = _split_state_dict(rank_state_dict) - self._aggregate(rank_state_dict["fp16_param"]) - self._aggregate(rank_state_dict["fp32_param"]) - self._aggregate(rank_state_dict["optimizer"]) - - for k in self.sharded_params: - self._reshape_tensor(k) - return self.aggregate_state_dict diff --git a/orttraining/orttraining/python/training/model_desc_validation.py b/orttraining/orttraining/python/training/model_desc_validation.py deleted file mode 100644 index dd3f4cb95cd59..0000000000000 --- a/orttraining/orttraining/python/training/model_desc_validation.py +++ /dev/null @@ -1,408 +0,0 @@ -from collections import namedtuple - -import cerberus -import torch - -from ._utils import static_vars - -LEARNING_RATE_IO_DESCRIPTION_NAME = "__learning_rate" -ALL_FINITE_IO_DESCRIPTION_NAME = "__all_finite" -LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME = "__loss_scale_input_name" -GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME = "__gradient_accumulation_name" - - -class _ORTTrainerModelDesc: - def __init__(self, model_desc): - # Keep a copy of original input for debug - self._original = dict(model_desc) - - # Global counter used to validate occurrences of 'is_loss=True' whithin 'model_desc.outputs' - # A stateless validator is used for each tuple, but validation accross the whole list of tuple is needed - # because just one 'is_loss=True' is allowed withing 'model_desc.outputs' list of tuples - _model_desc_outputs_validation.loss_counter = 0 - - # Used for logging purposes - self._main_class_name = self.__class__.__name__ - - # Validates user input - self._validated = dict(self._original) - validator = cerberus.Validator(MODEL_DESC_SCHEMA) - self._validated = validator.validated(self._validated) - if self._validated is None: - raise ValueError(f"Invalid model_desc: {validator.errors}") - - # Normalize inputs to a list of namedtuple(name, shape) - self._InputDescription = namedtuple("InputDescription", ["name", "shape"]) - self._InputDescriptionTyped = namedtuple("InputDescriptionTyped", ["name", "shape", "dtype"]) - for idx, input in enumerate(self._validated["inputs"]): - self._validated["inputs"][idx] = self._InputDescription(*input) - - # Normalize outputs to a list of namedtuple(name, shape, is_loss) - self._OutputDescription = namedtuple("OutputDescription", ["name", "shape", "is_loss"]) - self._OutputDescriptionTyped = namedtuple( - "OutputDescriptionTyped", ["name", "shape", "is_loss", "dtype", "dtype_amp"] - ) - for idx, output in enumerate(self._validated["outputs"]): - if len(output) == 2: - self._validated["outputs"][idx] = self._OutputDescription(*output, False) - else: - self._validated["outputs"][idx] = self._OutputDescription(*output) - - # Hard-code learning rate, all_finite descriptors - self.learning_rate = self._InputDescriptionTyped(LEARNING_RATE_IO_DESCRIPTION_NAME, [1], torch.float32) - - # Convert dict in object - for k, v in self._validated.items(): - setattr(self, k, self._wrap(v)) - - def __repr__(self): - """Pretty representation for a 
model description class""" - - pretty_msg = "Model description:\n" - - # Inputs - inputs = [] - for i_desc in self.inputs: - if isinstance(i_desc, self._InputDescription): - inputs.append(f"(name={i_desc.name}, shape={i_desc.shape})") - elif isinstance(i_desc, self._InputDescriptionTyped): - inputs.append(f"(name={i_desc.name}, shape={i_desc.shape}, dtype={i_desc.dtype})") - else: - raise ValueError(f"Unexpected type {type(i_desc)} for input description") - - pretty_msg += "\nInputs:" - for idx, item in enumerate(inputs): - pretty_msg += f"\n\t{idx}: {item}" - - # Outputs - outputs = [] - for o_desc in self.outputs: - if isinstance(o_desc, self._OutputDescription): - outputs.append(f"(name={o_desc.name}, shape={o_desc.shape})") - elif isinstance(o_desc, self._OutputDescriptionTyped): - outputs.append( - f"(name={o_desc.name}, shape={o_desc.shape}, dtype={o_desc.dtype}, dtype_amp={o_desc.dtype_amp})" - ) - else: - raise ValueError(f"Unexpected type {type(o_desc)} for output description") - pretty_msg += "\nOutputs:" - for idx, item in enumerate(outputs): - pretty_msg += f"\n\t{idx}: {item}" - - # Learning rate - if self.learning_rate: - pretty_msg += "\nLearning rate: " - pretty_msg += ( - f"(name={self.learning_rate.name}, shape={self.learning_rate.shape}, dtype={self.learning_rate.dtype})" - ) - - # Mixed precision - if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) or getattr( - self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None - ): - pretty_msg += "\nMixed Precision:" - if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None): - pretty_msg += "\n\tis gradients finite: " - pretty_msg += ( - f"(name={self.all_finite.name}, shape={self.all_finite.shape}, dtype={self.all_finite.dtype})" - ) - if getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None): - pretty_msg += "\n\tloss scale input name: " - pretty_msg += f"(name={self.loss_scale_input.name}, shape={self.loss_scale_input.shape}, dtype={self.loss_scale_input.dtype})" - - # Gradient Accumulation steps - if self.gradient_accumulation: - pretty_msg += "\nGradient Accumulation: " - pretty_msg += f"(name={self.gradient_accumulation.name}, shape={self.gradient_accumulation.shape}, dtype={self.gradient_accumulation.dtype})" - - return pretty_msg - - def add_type_to_input_description(self, index, dtype): - """Updates an existing input description at position 'index' with 'dtype' type information - - Args: - index (int): position within 'inputs' description - dtype (torch.dtype): input data type - """ - - assert isinstance(index, int) and index >= 0, "input 'index' must be a positive int" - assert isinstance(dtype, torch.dtype), "input 'dtype' must be a torch.dtype type" - existing_values = (*self.inputs[index],) - if isinstance(self.inputs[index], self._InputDescriptionTyped): - existing_values = (*existing_values[:-1],) - self.inputs[index] = self._InputDescriptionTyped(*existing_values, dtype) - - def add_type_to_output_description(self, index, dtype, dtype_amp=None): - """Updates an existing output description at position 'index' with 'dtype' type information - - Args: - index (int): position within 'inputs' description - dtype (torch.dtype): input data type - dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision - """ - - assert isinstance(index, int) and index >= 0, "output 'index' must be a positive int" - assert isinstance(dtype, torch.dtype), "output 'dtype' must be a torch.dtype type" - assert dtype_amp is None or isinstance( - dtype_amp, torch.dtype - ), "output 'dtype_amp' must be either 
None or torch.dtype type" - existing_values = (*self.outputs[index],) - if isinstance(self.outputs[index], self._OutputDescriptionTyped): - existing_values = (*existing_values[:-2],) - self.outputs[index] = self._OutputDescriptionTyped(*existing_values, dtype, dtype_amp) - - @property - def gradient_accumulation(self): - return getattr(self, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, None) - - @gradient_accumulation.setter - def gradient_accumulation(self, name): - self._add_output_description( - self, name, [1], False, torch.bool, None, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - @property - def all_finite(self): - return getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) - - @all_finite.setter - def all_finite(self, name): - self._add_output_description( - self, name, [1], False, torch.bool, None, ALL_FINITE_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - @property - def loss_scale_input(self): - return getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None) - - @loss_scale_input.setter - def loss_scale_input(self, name): - self._add_input_description( - self, name, [], torch.float32, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - def _add_input_description(self, node, name, shape, dtype=None, attr_name=None, ignore_duplicate=False): - """Add a new input description into the node object - - If 'dtype' is specified, a typed input description namedtuple(name, shape, dtype) is created. - Otherwise an untyped input description namedtuple(name, shape) is created instead. - - Args: - node (list or object): node to append input description to. When 'node' is 'self.inputs', - a new input description is appended to the list. - Otherwise, a new input description is created as an attribute into 'node' with name 'attr_name' - name (str): name of input description - shape (list): shape of input description - dtype (torch.dtype): input data type - attr_name (str, default is None): friendly name to allow direct access to the output description - ignore_duplicate (bool, default is False): silently skips addition of duplicate inputs - """ - - assert isinstance(name, str) and len(name) > 0, "'name' is an invalid input name" - not_found = True - if not ignore_duplicate: - if id(node) == id(self.inputs): - not_found = all([name not in i_desc.name for i_desc in node]) - assert not_found, f"'name' {name} already exists in the inputs description" - else: - not_found = attr_name not in dir(self) - assert not_found, f"'attr_name' {attr_name} already exists in the 'node'" - elif not not_found: - return - assert isinstance(shape, list) and all( - [(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape] - ), "'shape' must be a list of int or str with length at least 1" - assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type" - if dtype: - new_input_desc = self._InputDescriptionTyped(name, shape, dtype) - else: - new_input_desc = self._InputDescription(name, shape) - - if id(node) == id(self.inputs): - self.inputs.append(new_input_desc) - else: - assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'" - setattr(node, attr_name, new_input_desc) - - def _add_output_description( - self, node, name, shape, is_loss, dtype=None, dtype_amp=None, attr_name=None, ignore_duplicate=False - ): - """Add a new output description into the node object as a tuple - - When (name, shape, is_loss, dtype) is specified, a typed output description is created - Otherwise an untyped 
output description (name, shape, is_loss) is created instead - - Args: - node (list or object): node to append output description to. When 'node' is 'self.outputs', - a new output description is appended to the list. - Otherwise, a new output description is created as an attribute into 'node' with name 'attr_name' - name (str): name of output description - shape (list): shape of output description - is_loss (bool): specifies whether this output is a loss - dtype (torch.dtype): input data type - dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision. - attr_name (str, default is None): friendly name to allow direct access to the output description - ignore_duplicate (bool, default is False): silently skips addition of duplicate outputs - """ - - assert isinstance(name, str) and len(name) > 0, "'name' is an invalid output name" - assert isinstance(shape, list) and all( - [(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape] - ), "'shape' must be a list of int or str with length at least 1" - assert isinstance(is_loss, bool), "'is_loss' must be a bool" - - not_found = True - if not ignore_duplicate: - if id(node) == id(self.outputs): - not_found = all([name not in o_desc.name for o_desc in node]) - assert not_found, f"'name' {name} already exists in the outputs description" - assert ( - all([not o_desc.is_loss for o_desc in node]) if is_loss else True - ), "Only one 'is_loss' is supported at outputs description" - else: - not_found = attr_name not in dir(self) - assert not_found, f"'attr_name' {attr_name} already exists in the 'node'" - elif not not_found: - return - - assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type" - if dtype: - new_output_desc = self._OutputDescriptionTyped(name, shape, is_loss, dtype, None) - else: - new_output_desc = self._OutputDescription(name, shape, is_loss) - - if id(node) == id(self.outputs): - self.outputs.append(new_output_desc) - else: - assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'" - setattr(node, attr_name, new_output_desc) - - def _wrap(self, v): - """Add 'v' as self's attribute to allow direct access as self.v""" - if isinstance(v, (list)): - return type(v)([self._wrap(v) for v in v]) - elif isinstance( - v, - ( - self._InputDescription, - self._InputDescriptionTyped, - self._OutputDescription, - self._OutputDescriptionTyped, - ), - ): - return v - elif isinstance(v, (tuple)): - return type(v)([self._wrap(v) for v in v]) - elif isinstance(v, (dict, int, float, bool, str)): - return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v - else: - raise ValueError( - f"Unsupported type for model_desc ({v})." - "Only int, float, bool, str, list, tuple and dict are supported" - ) - - -class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc): - r"""Internal class used by ONNX Runtime training backend for input validation - - NOTE: Users MUST NOT use this class in any way! - """ - - def __init__(self, main_class_name, model_desc): - # Used for logging purposes - self._main_class_name = main_class_name - - # Convert dict in object - for k, v in dict(model_desc).items(): - setattr(self, k, self._wrap(v)) - - -def _model_desc_inputs_validation(field, value, error): - r"""Cerberus custom check method for 'model_desc.inputs' - - 'model_desc.inputs' is a list of tuples. 
- The list has variable length, but each tuple has size 2 - - The first element of the tuple is a string which represents the input name - The second element is a list of shapes. Each shape must be either an int or string. - Empty list represents a scalar output - - Validation is done within each tuple to enforce the schema described above. - - Example: - - .. code-block:: python - - model_desc['inputs'] = [('input1', ['batch', 1024]), - ('input2', []) - ('input3', [512])] - """ - - if not isinstance(value, tuple) or len(value) != 2: - error(field, "must be a tuple with size 2") - if not isinstance(value[0], str): - error(field, "the first element of the tuple (aka name) must be a string") - if not isinstance(value[1], list): - error(field, "the second element of the tuple (aka shape) must be a list") - else: - for shape in value[1]: - if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool): - error(field, "each shape must be either a string or integer") - - -@static_vars(loss_counter=0) -def _model_desc_outputs_validation(field, value, error): - r"""Cerberus custom check method for 'model_desc.outputs' - - 'model_desc.outputs' is a list of tuples with variable length. - The first element of the tuple is a string which represents the output name - The second element is a list of shapes. Each shape must be either an int or string. - Empty list represents a scalar output - The third element is optional and is a flag that signals whether the output is a loss value - - Validation is done within each tuple to enforce the schema described above, but also - throughout the list of tuples to ensure a single 'is_loss=True' occurrence. - - Example: - - .. code-block:: python - - model_desc['outputs'] = [('output1', ['batch', 1024], is_loss=True), - ('output2', [], is_loss=False) - ('output3', [512])] - """ - - if not isinstance(value, tuple) or len(value) < 2 or len(value) > 3: - error(field, "must be a tuple with size 2 or 3") - if len(value) == 3 and not isinstance(value[2], bool): - error(field, "the third element of the tuple (aka is_loss) must be a boolean") - elif len(value) == 3: - if value[2]: - _model_desc_outputs_validation.loss_counter += 1 - if _model_desc_outputs_validation.loss_counter > 1: - error(field, "only one is_loss can bet set to True") - if not isinstance(value[0], str): - error(field, "the first element of the tuple (aka name) must be a string") - if not isinstance(value[1], list): - error(field, "the second element of the tuple (aka shape) must be a list") - else: - for shape in value[1]: - if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool): - error(field, "each shape must be either a string or integer") - - -# Validation schema for model description dictionary -MODEL_DESC_SCHEMA = { - "inputs": { - "type": "list", - "required": True, - "minlength": 1, - "schema": {"check_with": _model_desc_inputs_validation}, - }, - "outputs": { - "type": "list", - "required": True, - "minlength": 1, - "schema": {"check_with": _model_desc_outputs_validation}, - }, -} diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 4977272de5ac9..8efbe16d7d61d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -412,14 +412,24 @@ def 
_matmul4bit_export(g, n, *args, **kwargs): return None quant_state = args[4] - absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state + if isinstance(quant_state, list): + # version <= 0.41.1 + absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state + nested = compressed_stats is not None + else: + # version > 0.41.1 + absmax = quant_state.absmax + shape = quant_state.shape + blocksize = quant_state.blocksize + nested = quant_state.nested + quant_type = quant_state.quant_type # MatMulBnb4's blocksize needs to be a power of 2 and not smaller than 16 if blocksize < 16 or blocksize & (blocksize - 1) != 0: return None # MatMulBnb4 does not support double de-quantization (e.g. absmax is int, needs to be dequantized too) - if compressed_stats is not None: + if nested: return None # The PyTorch linear weight shape is [out_feature, in_feature] diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py deleted file mode 100644 index d5a488c436a1d..0000000000000 --- a/orttraining/orttraining/python/training/orttrainer.py +++ /dev/null @@ -1,1537 +0,0 @@ -import copy -import io -import os -import warnings -from functools import partial -from inspect import signature - -import numpy as np -import onnx -import torch - -import onnxruntime as ort -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - -from . import _checkpoint_storage, _utils, amp, checkpoint, optim, postprocess -from .model_desc_validation import _ORTTrainerModelDesc -from .orttrainer_options import ORTTrainerOptions - - -class TrainStepInfo: - r"""Private class used to store runtime information from current train step. - - After every train step, :py:meth:`ORTTrainer.train_step` updates the internal instance of - :py:class:`.TrainStepInfo` residing on :py:class:`.ORTTrainer` with relevant information - from the forward pass. - - This class shouldn't be accessed directly by the user, unless they really know what they are doing. - Instead, :py:class:`.ORTTrainer` passes it to relevant class methods automatically, - such as :py:method:`._LRScheduler.get_lr` or :py:class:`.LossScaler.update`. - - Args: - optimizer_config (optim._OptimizerConfig): reference to optimizer config - all_finite (bool, default is True): flag that indicates whether all gradients are still finite after last step - fetches (list of str, default is []): list of output names to fetch from train_step/eval_step. Set it to [] to reset normal behavior. - optimization_step (int): indicates the number of optimizations performed. Used for learning rate scheduling - step (int): indicates current training step. Used for gradient accumulation - - Example: - - .. 
code-block:: python - - info = TrainStepInfo(optimizer_config=optim.SGDConfig(lr=0.01)) - if info.all_finite: - print(f'Yay, all gradients are finite at {step} step!') - - """ - - def __init__(self, optimizer_config, all_finite=True, fetches=[], optimization_step=0, step=0): # noqa: B006 - assert isinstance(optimizer_config, optim._OptimizerConfig), "optimizer_config must be a optim._OptimizerConfig" - assert isinstance(all_finite, bool), "all_finite must be a bool" - assert isinstance(fetches, list) and all( - [isinstance(item, str) for item in fetches] - ), "fetches must be a list of str" - assert isinstance(optimization_step, int) and optimization_step >= 0, "optimization_step must be a positive int" - assert isinstance(step, int) and step >= 0, "step must be a positive int" - - self.optimizer_config = optimizer_config - self.all_finite = all_finite - self.fetches = fetches - self.optimization_step = optimization_step - self.step = step - - -class ORTTrainer: - r"""Pytorch frontend for ONNX Runtime training - - Entry point that exposes the C++ backend of ORT as a Pytorch frontend. - - Args: - model (torch.nn.Module or onnx.ModelProto): either a PyTorch or ONNX model. - When a PyTorch model and :py:attr:`loss_fn` are specified, :py:attr:`model` and :py:obj:`loss_fn` are combined. - When a ONNX model is provided, the loss is identified by the flag :py:obj:`is_loss=True` in one of the :py:attr:`.model_desc.outputs` entries. - model_desc (dict): model input and output description. - This is used to identify inputs and outputs and their shapes, so that ORT can generate back propagation graph, plan memory allocation for - training, and perform optimizations. - :py:attr:`model_desc` must be consistent with the training :py:attr:`model` and have the following (:py:obj:`dict`) schema - :py:obj:`{ 'inputs': [tuple(name, shape)], 'outputs': [tuple(name, shape, is_loss)]}`. - :py:attr:`name` is a string representing the name of input or output of the model. - For :py:obj:`model_desc['inputs']` entries, :py:attr:`name` must match input names of the original PyTorch model's :py:meth:`torch.nn.Module.forward` method. - For ONNX models, both name and order of input names must match. - For :py:obj:`model_desc['outputs']` entries, the order must match the original PyTorch's output as returned by :py:meth:`torch.nn.Module.forward` method. - For ONNX models, both name and order of output names must match. - :py:attr:`shape` is a list of string or integers that describes the shape of the input/output. - Each dimension size can be either a string or an int. String means the dimension size is dynamic, while integers mean static dimensions. - An empty list implies a scalar. - Lastly, :py:attr:`is_loss` is a boolean (default is False) that flags if this output is considered a loss. - ORT backend needs to know which output is loss in order to generate back propagation graph. - Loss output must be specified when either :py:attr:`loss_fn` is specified or when loss is embedded in the model. - Note that only one loss output is supported per model. - optimizer_config (optim._OptimizerConfig): optimizer config. - One of :py:class:`.optim.AdamConfig`, :py:class:`.optim.LambConfig` or :py:class:`.optim.SGDConfig`. - loss_fn (callable, default is None): a PyTorch loss function. - It takes two inputs [prediction, label] and outputs a scalar loss tensor. - If provided, :py:attr:`loss_fn` is combined with the PyTorch :py:attr:`model` to form a combined PyTorch model. 
- Inputs to the combined PyTorch model are concatenation of the :py:attr:`model`'s input and :py:attr:`loss_fn`'s label input. - Outputs of the combined PyTorch model are concatenation of :py:attr:`loss_fn`'s loss output and :py:attr:`model`'s outputs. - options (ORTTrainerOptions, default is None): options for additional features. - Example: - - .. code-block:: python - - model = ... - loss_fn = ... - model_desc = { - "inputs": [ - ("input_ids", ["batch", "max_seq_len_in_batch"]), - ("attention_mask", ["batch", "max_seq_len_in_batch"]), - ("token_type_ids", ["batch", "max_seq_len_in_batch"]), - ("masked_lm_labels", ["batch", "max_seq_len_in_batch"]), - ("next_sentence_label", ["batch", 1]) - ], - "outputs": [ - ("loss", [], True), - ], - } - optim_config = optim.LambConfig(param_groups = [ { 'params' : ['model_param0'], 'alpha' : 0.8, 'beta' : 0.7}, - { 'params' : ['model_param1' , 'model_param_2'], 'alpha' : 0.0} - ], - alpha=0.9, beta=0.999) - ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn) - """ - - def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): - warnings.warn( - "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", - FutureWarning, - ) - - assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model" - assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'" - assert isinstance( - optim_config, optim._OptimizerConfig - ), "'optim_config' is required and must be any of 'AdamConfig', 'LambConfig' or 'SGDConfig'" - assert loss_fn is None or ( - callable(loss_fn) and len(signature(loss_fn).parameters) == 2 - ), "'loss_fn' must be either 'None' or a callable with two parameters" - assert options is None or isinstance( - options, ORTTrainerOptions - ), "'options' must be either 'None' or 'ORTTrainerOptions'" - - # Model + Loss validation - # Supported combinarios are - # ---------------------------------------- - # | | Model | Loss | - # ---------------------------------------- - # | 1 | torch.nn.Module | None | - # | 2 | torch.nn.Module | torch.nn.Module | - # | 3 | ONNX | None | - # ---------------------------------------- - self._torch_model = None - self._onnx_model = None - if isinstance(model, torch.nn.Module): - assert loss_fn is None or isinstance( - model, torch.nn.Module - ), "'loss_fn' must be either 'None' or 'torch.nn.Module'" - self._torch_model = model - self.loss_fn = loss_fn - # TODO: Remove when experimental checkpoint functions are removed. 
- self._torch_state_dict_keys = list(model.state_dict().keys()) - elif isinstance(model, onnx.ModelProto): - assert loss_fn is None, "'loss_fn' must not be specified when 'model' is an ONNX model" - self._onnx_model = model - self.loss_fn = None - else: - raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'") - - self.model_desc = _ORTTrainerModelDesc(model_desc) - self.optim_config = optim_config - - # ORTTrainerOptions - if not options: - options = ORTTrainerOptions() - self.options = options - if self.options.mixed_precision.enabled and not self.options.mixed_precision.loss_scaler: - # TODO: Move this to model_desc_validation.py - self.options.mixed_precision.loss_scaler = amp.loss_scaler.DynamicLossScaler() - # Post processing ONNX model given as input - if self._onnx_model: - if self.options._internal_use.enable_internal_postprocess: - self._onnx_model = postprocess.run_postprocess(self._onnx_model) - if self.options._internal_use.extra_postprocess: - self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model) - assert isinstance(self._onnx_model, onnx.ModelProto), "'extra_postprocess' must return a ONNX model" - - # When input model is already ONNX (and not exported from Pytorch within ORTTrainer), - # append 'dtype' from ONNX into model description's - for idx_i, i_desc in enumerate(self.model_desc.inputs): - dtype = None - for onnx_input in self._onnx_model.graph.input: - if onnx_input.name == i_desc.name: - dtype = _utils.dtype_onnx_to_torch(onnx_input.type.tensor_type.elem_type) - self.model_desc.add_type_to_input_description(idx_i, dtype) - break - assert dtype is not None, f"ONNX model with unknown input type ({i_desc.name})" - for idx_o, o_desc in enumerate(self.model_desc.outputs): - dtype = None - for onnx_output in self._onnx_model.graph.output: - if onnx_output.name == o_desc.name: - dtype = _utils.dtype_onnx_to_torch(onnx_output.type.tensor_type.elem_type) - self.model_desc.add_type_to_output_description(idx_o, dtype) - break - assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})" - - try: - from torch.utils.cpp_extension import ROCM_HOME - - self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) - except ImportError: - self.is_rocm_pytorch = False - - # TODO: Remove when experimental checkpoint functions are removed. - self._state_dict = {} - - self._train_step_info = TrainStepInfo(self.optim_config) - self._training_session = None - self._load_state_dict = None - self._init_session( - provider_options=self.options._validated_opts["provider_options"], - session_options=self.options.session_options, - ) - - def eval_step(self, *args, **kwargs): - r"""Evaluation step method - - Args: - *args: Arbitrary arguments that are used as model input (data only) - **kwargs: Arbitrary keyword arguments that are used as model input (data only) - - Returns: - ordered :py:obj:`list` with model outputs as described by :py:attr:`.ORTTrainer.model_desc` - """ - # Get data. CombineTorchModelLossFn takes label as last input and outputs loss first - sample_input = self._prepare_model_input(self.model_desc.inputs, None, None, *args, **kwargs) - - # Export model to ONNX - if self._onnx_model is None: - if self._torch_model is not None: - self._init_onnx_model(sample_input) - else: - raise RuntimeError("Model is uninitialized. 
Only ONNX and PyTorch models are supported") - - # Prepare input/output description - inputs_desc = self.model_desc.inputs - outputs_desc = self.model_desc.outputs - if self._train_step_info.fetches: - outputs_desc = [o_desc for o_desc in outputs_desc if o_desc.name in self._train_step_info.fetches] - if len(outputs_desc) != len(self._train_step_info.fetches): - raise RuntimeError("The specified fetches list contains invalid output names") - - # Normalize input - if not isinstance(sample_input, (list, tuple)): - sample_input = (sample_input,) - - # RunOptions - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - run_options.training_mode = False - - # Run a eval step and return - session_run_results = self._training_session_run_helper( - False, sample_input, inputs_desc, outputs_desc, run_options - ) - - # Output must be returned in the same order as defined in the model description - results = [session_run_results[o_desc.name] for o_desc in outputs_desc] - return results[0] if len(results) == 1 else results - - def save_as_onnx(self, path): - r"""Persists ONNX model into :py:attr:`path` - - The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard. - The graph includes full information, including inference and training metadata. - - Args: - path (str): Full path, including filename, to save the ONNX model in the filesystem - - Raises: - RuntimeWarning: raised when neither `train_step` or `eval_step` was called at least once - ValueError: raised when `path` is not valid path - """ - if not self._training_session: - warnings.warn( - "Training session is not initialized yet. " - "'train_step' or 'eval_step' methods must be executed at least once before calling 'save_as_onnx()'." - ) - return - state_tensors = self._training_session.get_state() - self._update_onnx_model_initializers(state_tensors) - - assert isinstance(path, str), "'path' must be a valid path string" - dir_name = os.path.dirname(path) - file_name = os.path.basename(path) - if (dir_name and not os.path.exists(dir_name)) or not file_name: - warnings.warn("'path' is not valid or does not exist") - return - - with open(path, "wb") as f: - f.write(self._onnx_model.SerializeToString()) - - def _check_model_export(self, input): - from numpy.testing import assert_allclose - from onnx import TensorProto, helper, numpy_helper # noqa: F401 - - onnx_model_copy = copy.deepcopy(self._onnx_model) - - # Mute the dropout nodes - dropout_nodes = [n for n in onnx_model_copy.graph.node if n.op_type == "Dropout"] - for node in dropout_nodes: - ratio_node = next(n for n in onnx_model_copy.graph.node if node.input[1] in n.output) - training_mode_node = next(n for n in onnx_model_copy.graph.node if node.input[2] in n.output) - - training_mode_node.attribute.pop() - ratio_node.attribute.pop() - new_training_mode_arr = np.array(False, dtype=bool) - new_ratio_arr = np.array(0.0, dtype=np.float32) - new_training_mode = numpy_helper.from_array(new_training_mode_arr) - new_ratio = numpy_helper.from_array(new_ratio_arr) - training_mode_node.attribute.add().t.CopyFrom(new_training_mode) - ratio_node.attribute.add().t.CopyFrom(new_ratio) - training_mode_node.attribute[0].type = 4 - ratio_node.attribute[0].type = 4 - training_mode_node.attribute[0].name = "value" - ratio_node.attribute[0].name = "value" - - _inference_sess = ort.InferenceSession( - onnx_model_copy.SerializeToString(), providers=ort.get_available_providers() - ) - inf_inputs = {} - for i, input_elem in enumerate(input): - 
inf_inputs[_inference_sess.get_inputs()[i].name] = input_elem.cpu().numpy() - _inference_outs = _inference_sess.run(None, inf_inputs) - for torch_item, ort_item in zip(self.torch_sample_outputs, _inference_outs): - assert_allclose( - torch_item, - ort_item, - rtol=1e-2, - atol=1e-6, - err_msg="Mismatch between outputs of PyTorch model and exported ONNX model. " - "Note that different backends may exhibit small computational differences." - "If this is within acceptable margin, or if there is random generator " - "in the model causing inevitable mismatch, you can proceed training by " - "setting the flag debug.check_model_export to False.", - ) - - def train_step(self, *args, **kwargs): - r"""Train step method - - After forward pass, an ordered list with all outputs described at :py:attr:`ORTTrainer.model_desc` is returned. - Additional information relevant to the train step is maintend by :py:attr:`ORTTrainer._train_step_info`. - See :py:class:`.TrainStepInfo` for details. - - Args: - *args: Arbitrary arguments that are used as model input (data only) - **kwargs: Arbitrary keyword arguments that are used as model input (data only) - - Returns: - ordered :py:obj:`list` with model outputs as described by :py:attr:`ORTTrainer.model_desc` - """ - # Export model to ONNX - if self._onnx_model is None: - sample_input = self._prepare_model_input(self.model_desc.inputs, None, None, *args, **kwargs) - self._init_onnx_model(sample_input) - - # Debug Model Export if indicated - if self.options.debug.check_model_export: - self._check_model_export(sample_input) - - # Prepare inputs+lr and output descriptions - inputs_desc = self._model_desc_inputs_with_lr - outputs_desc = self.model_desc.outputs - - # Train step must be incremented *before* gradient accumulation code - # Gradients are accumulated when - # self._train_step_info.step % self.options.batch.gradient_accumulation_steps != 0, - # and they are updated otherwise - self._train_step_info.step += 1 - - # RunOptions - run_options = None - mixed_precision_without_fetches = False - if self._train_step_info.fetches: - outputs_desc = [o_desc for o_desc in outputs_desc if o_desc.name in self._train_step_info.fetches] - if len(outputs_desc) != len(self._train_step_info.fetches): - raise RuntimeError("The specified fetches list contains invalid output names") - elif self._train_step_info.step % self.options.batch.gradient_accumulation_steps != 0: - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - outputs_desc = self._model_desc_outputs_with_gradient_accumulation - elif self.options.mixed_precision.enabled: - mixed_precision_without_fetches = True - outputs_desc = self._model_desc_outputs_with_all_finite - - # Update Learning Rate if Necessary - lr = self.optim_config.lr - if self.options.lr_scheduler: - lr = self.options.lr_scheduler._step(self._train_step_info)[0] - - # Loss Scale for mixed precision - loss_scale = None - if self.options.mixed_precision.enabled: - loss_scaler = self.options.mixed_precision.loss_scaler - assert loss_scaler, "Loss scaler is required when mixed precision is enabled" - loss_scale = loss_scaler.loss_scale - inputs_desc = self._model_desc_inputs_with_lr_and_loss_scale - - # Get data. 
CombineTorchModelLossFn takes label as last input and outputs loss first - input = self._prepare_model_input(inputs_desc, lr, loss_scale, *args, **kwargs) - - # Normalize input - if not isinstance(args, (list, tuple)): - args = (args,) - - # Run a train step and return - session_run_results = self._training_session_run_helper(True, input, inputs_desc, outputs_desc, run_options) - if mixed_precision_without_fetches: - # After session run with all_fp32_gradients_finite, we need to clear the training I/O binding's output - # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce - # because all_fp32_gradients_finite is still in the feed. - self._train_io_binding.clear_binding_outputs() - - is_all_finite = session_run_results[self.model_desc.all_finite.name] - self._train_step_info.all_finite = is_all_finite - if loss_scaler: - loss_scaler.update(self._train_step_info) - if is_all_finite: - # Optimization step must be incremented *after* optimization is successful - self._train_step_info.optimization_step += 1 - elif self._train_step_info.step % self.options.batch.gradient_accumulation_steps == 0: - # Optimization step must be incremented *after* optimization is successful - self._train_step_info.optimization_step += 1 - - # Output must be returned in the same order as defined in the model description - # or in the order specified by TrainStepInfo.fetches, if applicable - if self._train_step_info.fetches: - results = [session_run_results[o_desc] for o_desc in self._train_step_info.fetches] - else: - results = [session_run_results[o_desc.name] for o_desc in self.model_desc.outputs] - return results[0] if len(results) == 1 else results - - def _convert_torch_model_loss_fn_to_onnx(self, inputs, device): - # Dynamic axes - dynamic_axes = {} - for input in self.model_desc.inputs: - symbolic_axis = {} - for i, axis in enumerate(input.shape): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[input.name] = symbolic_axis - for output in self.model_desc.outputs: - symbolic_axis = {} - for i, axis in enumerate(output.shape): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[output.name] = symbolic_axis - - if isinstance(inputs, torch.Tensor): - inputs = [inputs] - if isinstance(inputs, dict): - sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs] - elif isinstance(inputs, (list, tuple)): - sample_inputs = [ - input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs) - ] - else: - raise RuntimeError( - "Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported." 
- ) - - # PyTorch ONNX exporter does not match argument names - # This is an issue because the ONNX graph depends on all inputs to be specified - - # Validate loss_fn - if self.loss_fn: - sig_loss = signature(self.loss_fn) - if len(sig_loss.parameters) != 2: - raise RuntimeError("loss function should take two arguments - predict and label.") - - # Basic input names from model - input_names = [input.name for input in self.model_desc.inputs] - sig = signature(self._torch_model.forward) - ordered_input_list = list(sig.parameters.keys()) - - # Label from loss_fn goes after model input - if self.loss_fn: - ordered_input_list = [*ordered_input_list, list(sig_loss.parameters.keys())[1]] - - class CombineTorchModelLossFnWrapInput(torch.nn.Module): - def __init__(self, model, loss_fn, input_names): - super().__init__() - self.model = model - self.loss_fn = loss_fn - self.input_names = input_names - - def forward(self, *inputs): - sig = signature(self.model.forward) - - input_dict = {} - for key in sig.parameters: - if key in self.input_names: - input_dict[key] = inputs[self.input_names.index(key)] - - model_out = self.model(**input_dict) - if self.loss_fn is None: - return model_out - - label = inputs[-1] - preds = model_out - return self.loss_fn(preds, label), preds - - model = CombineTorchModelLossFnWrapInput(self._torch_model, self.loss_fn, input_names) - - # Do an inference to grab output types - model.eval() - with torch.no_grad(): - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - try: - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(model) - except Exception: - model_copy = model - warnings.warn( - "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." - " Compute will continue, but unexpected results may occur!" - ) - sample_outputs = model_copy(*sample_inputs_copy) - self.torch_sample_outputs = sample_outputs - model.train() - - if isinstance(sample_outputs, torch.Tensor): - sample_outputs = [sample_outputs] - - # Append 'dtype' for model description's inputs/outputs - for idx_i, sample_input in enumerate(sample_inputs): - if idx_i < len(self.model_desc.inputs): - self.model_desc.add_type_to_input_description(idx_i, sample_input.dtype) - for idx_o, sample_output in enumerate(sample_outputs): - if idx_o < len(self.model_desc.outputs): - self.model_desc.add_type_to_output_description(idx_o, sample_output.dtype) - - # Export the model to ONNX - f = io.BytesIO() - - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - - # Handle contrib OPs support - from onnxruntime.tools import pytorch_export_contrib_ops - - if self.options._internal_use.enable_onnx_contrib_ops: - pytorch_export_contrib_ops.register() - else: - # Unregister in case they were registered in previous calls. - pytorch_export_contrib_ops.unregister() - - # Export torch.nn.Module to ONNX - torch.onnx.export( - model, - tuple(sample_inputs_copy), - f, - input_names=[input.name for input in self.model_desc.inputs], - output_names=[output.name for output in self.model_desc.outputs], - opset_version=self.options._internal_use.onnx_opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False, - training=torch.onnx.TrainingMode.TRAINING, - ) - onnx_model = onnx.load_model_from_string(f.getvalue()) - - # Remove 'model.' 
prefix introduced by CombineTorchModelLossFn class - if isinstance(model, CombineTorchModelLossFnWrapInput): - replace_name_dict = {} - for n in onnx_model.graph.initializer: - if n.name.startswith("model."): - replace_name_dict[n.name] = n.name[len("model.") :] - n.name = replace_name_dict[n.name] - for n in onnx_model.graph.node: - for i, name in enumerate(n.input): - if name in replace_name_dict: - n.input[i] = replace_name_dict[name] - - return onnx_model - - def _create_ort_training_session(self, optimizer_state_dict=None, session_options=None, provider_options=None): - if optimizer_state_dict is None: - optimizer_state_dict = {} - # Validating frozen_weights names - unused_frozen_weights = [ - n - for n in self.options.utils.frozen_weights - if n not in [i.name for i in self._onnx_model.graph.initializer] - ] - if unused_frozen_weights: - raise RuntimeError(f"{unused_frozen_weights} params from 'frozen_weights' not found in the ONNX model.") - - # Get loss name from model description - loss_name = [item.name for item in self.model_desc.outputs if item.is_loss] - assert len(loss_name) == 1, f"Only one loss output is supported ({len(loss_name)} were specified)" - loss_name = loss_name[0] - - # Parse optimizer parameters - optimizer_attributes_map = {} - optimizer_int_attributes_map = {} - trainable_params = set() - for initializer in self._onnx_model.graph.initializer: - if initializer.name in self.options.utils.frozen_weights: - continue # only trainable parameters are passed to the backend - trainable_params.add(initializer.name) - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - not_in_param_groups = True - for param_group in self.optim_config.params: - if initializer.name not in param_group["params"]: - continue # keep looking for a matching param_group - not_in_param_groups = False - for k, v in param_group.items(): - # 'params' is not a hyper parameter, skip it. 
'lr' per weight is not supported - if k == "params" or k == "lr": - continue - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - - # set default values for params not found in groups - if not_in_param_groups: - for k, v in self.optim_config.defaults.items(): - if k == "lr": - continue - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - - self.options.distributed.horizontal_parallel_size = max(self.options.distributed.horizontal_parallel_size, 1) - self.options.distributed.data_parallel_size = ( - self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size - ) - - # TrainingParameters - ort_parameters = ort.TrainingParameters() - ort_parameters.loss_output_name = loss_name - ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled - ort_parameters.world_rank = self.options.distributed.world_rank - ort_parameters.world_size = self.options.distributed.world_size - ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps - ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation - ort_parameters.enable_adasum = self.options.distributed.enable_adasum - ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage - ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip - ort_parameters.set_gradients_as_graph_outputs = False - ort_parameters.use_memory_efficient_gradient = self.options.utils.memory_efficient_gradient - ort_parameters.training_optimizer_name = self.optim_config.name - ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name - ort_parameters.weights_to_train = trainable_params - ort_parameters.optimizer_attributes_map = optimizer_attributes_map - ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map - if bool(optimizer_state_dict): - ort_parameters.set_optimizer_initial_state(optimizer_state_dict) - - ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute - ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute - ort_parameters.transformer_layer_recompute = self.options.graph_transformer.transformer_layer_recompute - ort_parameters.number_recompute_layers = self.options.graph_transformer.number_recompute_layers - - ort_parameters.data_parallel_size = self.options.distributed.data_parallel_size - ort_parameters.horizontal_parallel_size = self.options.distributed.horizontal_parallel_size - ort_parameters.pipeline_parallel_size = self.options.distributed.pipeline_parallel.pipeline_parallel_size - ort_parameters.num_pipeline_micro_batches = ( - self.options.distributed.pipeline_parallel.num_pipeline_micro_batches - ) - ort_parameters.pipeline_cut_info_string = self.options.distributed.pipeline_parallel.pipeline_cut_info_string - # We have special handling for dictionary-typed option. - # sliced_schema._validated_opts is the original dictionary while sliced_schema is a _ORTTrainerOptionsInternal. 
- ort_parameters.sliced_schema = self.options.distributed.pipeline_parallel.sliced_schema._validated_opts - # We have special handling for dictionary-typed option. - # sliced_axes._validated_opts is the original dictionary while sliced_schema is a _ORTTrainerOptionsInternal. - ort_parameters.sliced_axes = self.options.distributed.pipeline_parallel.sliced_axes._validated_opts - ort_parameters.sliced_tensor_names = self.options.distributed.pipeline_parallel.sliced_tensor_names - - ort_parameters.model_after_graph_transforms_path = ( - self.options.debug.graph_save_paths.model_after_graph_transforms_path - ) - ort_parameters.model_with_gradient_graph_path = ( - self.options.debug.graph_save_paths.model_with_gradient_graph_path - ) - ort_parameters.model_with_training_graph_path = ( - self.options.debug.graph_save_paths.model_with_training_graph_path - ) - - # SessionOptions - session_options = ort.SessionOptions() if session_options is None else session_options - session_options.use_deterministic_compute = self.options.debug.deterministic_compute - if ( - self.options.graph_transformer.attn_dropout_recompute - or self.options.graph_transformer.gelu_recompute - or self.options.graph_transformer.transformer_layer_recompute - ): - session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED - if len(self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path) > 0: - session_options.optimized_model_filepath = ( - self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path - ) - - # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. - # for example, load_state_dict will be called before returing the function, and it calls _init_session again - del self._training_session - - # Set provider-specific options if needed - def get_providers(provider_options): - providers = ort.get_available_providers() - if provider_options: - for provider_name in provider_options: - if provider_name in providers: - providers[providers.index(provider_name)] = (provider_name, provider_options[provider_name]) - else: - providers.insert(0, (provider_name, provider_options[provider_name])) - # default: using cuda - elif "cuda" in self.options.device.id.lower(): - gpu_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)} - gpu_ep_name = "ROCMExecutionProvider" if self.is_rocm_pytorch else "CUDAExecutionProvider" - if self.options.device.mem_limit > 0: - gpu_ep_options["gpu_mem_limit"] = self.options.device.mem_limit - - if gpu_ep_name not in providers: - raise RuntimeError( - "ORTTrainer options specify a CUDA device but the {} provider is unavailable.".format( - cuda_ep_name # noqa: F821 - ) - ) - - providers[providers.index(gpu_ep_name)] = (gpu_ep_name, gpu_ep_options) - - return providers - - # TrainingSession - self._training_session = ort.TrainingSession( - self._onnx_model.SerializeToString(), ort_parameters, session_options, get_providers(provider_options) - ) - - # I/O bindings - self._train_io_binding = self._training_session.io_binding() - self._eval_io_binding = self._training_session.io_binding() - - def _init_onnx_model(self, inputs): - if self._onnx_model is not None: - return - - if self._torch_model is not None: - # PyTorch model is moved to cpu to save GPU memory - self._torch_model.cpu() - - # PyTorch buffers (created using 'register_buffer') shouldn't be trained - torch_buffers = list(dict(self._torch_model.named_buffers()).keys()) - 
self.options.utils.frozen_weights.extend(torch_buffers) - - # Export to ONNX - self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(inputs, "cpu") - - # Post processing for ONNX models expported from PyTorch - if self.options._internal_use.enable_internal_postprocess: - self._onnx_model = postprocess.run_postprocess(self._onnx_model) - if self.options._internal_use.extra_postprocess: - self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model) - - optimizer_state_dict = {} - if self._load_state_dict: - optimizer_state_dict = self._load_state_dict() - - self._init_session( - optimizer_state_dict, - session_options=self.options.session_options, - provider_options=self.options._validated_opts["provider_options"], - ) - - def _init_session(self, optimizer_state_dict={}, session_options=None, provider_options=None): # noqa: B006 - if self._onnx_model is None: - return - - if self.options.utils.run_symbolic_shape_infer: - self._onnx_model = SymbolicShapeInference.infer_shapes( - self._onnx_model, auto_merge=True, guess_output_rank=True - ) - - # Create training session used by train_step - # pass all optimizer states to the backend - self._create_ort_training_session( - optimizer_state_dict, session_options=session_options, provider_options=provider_options - ) - - # Update model description to update dtype when mixed precision is enabled - # C++ backend modifies model's output dtype from float32 to float16 for mixed precision - # Note that for training we must use float32 and for evaluation we must use float16 - for idx, o_desc in enumerate(self.model_desc.outputs): - if ( - self.options.mixed_precision.enabled - and o_desc.dtype == torch.float32 - and not self._training_session.is_output_fp32_node(o_desc.name) - ): - self.model_desc.add_type_to_output_description(idx, o_desc.dtype, torch.float16) - - # Update model description - self._model_desc_inputs_with_lr = [*self.model_desc.inputs, self.model_desc.learning_rate] - - # Update Mixed Precision, if applicable - if self.options.mixed_precision.enabled: - self.model_desc.loss_scale_input = self._training_session.loss_scale_input_name - self._model_desc_inputs_with_lr_and_loss_scale = [ - *self._model_desc_inputs_with_lr, - self.model_desc.loss_scale_input, - ] - self.model_desc.all_finite = _utils.get_all_gradients_finite_name_from_session(self._training_session) - self._model_desc_outputs_with_all_finite = [*self.model_desc.outputs, self.model_desc.all_finite] - elif self.options.mixed_precision.loss_scaler: - raise ValueError("Loss Scaler cannot be specified when Mixed Precision is not enabled") - - # Update Loss Scaler Input Name, if applicable - if self.options.mixed_precision.enabled and self.options.mixed_precision.loss_scaler: - self.options.mixed_precision.loss_scaler.input_name = self.model_desc.loss_scale_input.name - elif not self.options.mixed_precision.enabled and self.options.mixed_precision.loss_scaler: - raise ValueError("Loss Scaler cannot be specified when Mixed Precision is not enabled") - - # Update Gradient Accumulation, if applicable - if self.options.batch.gradient_accumulation_steps > 1: - self.model_desc.gradient_accumulation = _utils.get_gradient_accumulation_name_from_session( - self._training_session - ) - self._model_desc_outputs_with_gradient_accumulation = [ - *self.model_desc.outputs, - self.model_desc.gradient_accumulation, - ] - - # TODO: Remove when experimental checkpoint functions are removed - if self._state_dict: - checkpoint.experimental_load_state_dict(self, 
self._state_dict, self._load_state_dict_strict) - self._state_dict_debug = self._state_dict - self._state_dict = {} - - def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs): - # Normalize input to tuple of samples - if type(inputs) == tuple and len(inputs) == 1 and type(inputs[0]) == list: # noqa: E721 - input = tuple(inputs[0]) - else: - input = inputs - - # Append input from 'kwargs' - for input_desc in inputs_desc: - if input_desc.name in kwargs: - input = (*input, kwargs[input_desc.name]) - - # Append learning rate - extra_inputs = 0 - if lr is not None: - lr = torch.tensor([lr]) - input += (lr,) - extra_inputs += 1 - - # Append loss scale - if loss_scale is not None: - assert self.options.mixed_precision.enabled, "Loss scale cannot be used without mixed precision" - loss_scale = torch.tensor([loss_scale]) - input += (loss_scale,) - extra_inputs += 1 - - # Only assert length of input when fetches is not used - assert self._train_step_info.fetches or len(self.model_desc.inputs) + extra_inputs == len(input) - return input - - def _resolve_symbolic_dimensions(self, inputs, inputs_desc, outputs_desc): - outputs = copy.deepcopy(outputs_desc) - resolved_dims = {} - for input, i_desc in zip(inputs, inputs_desc): - for i_idx, i_axis in enumerate(i_desc.shape): - if isinstance(i_axis, str): - if i_axis not in resolved_dims: - resolved_dims[i_axis] = input.size()[i_idx] - else: - assert resolved_dims[i_axis] == input.size()[i_idx], f"Mismatch in dynamic shape {i_axis}" - - for o_desc in outputs: - for idx_o, o_axis in enumerate(o_desc.shape): - if isinstance(o_axis, str): - o_desc.shape[idx_o] = resolved_dims[o_axis] - - unknown_dim = [o_desc.name for dim in o_desc.shape for o_desc in outputs if isinstance(dim, str)] - if unknown_dim: - raise RuntimeError(f"Cannot execute model with unknown output dimensions ({unknown_dim}") - - return outputs - - def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_desc, run_options=None): - # Select IO binding - if is_train: - iobinding = self._train_io_binding - else: - iobinding = self._eval_io_binding - - # Get the list of the actual session inputs because unused inputs can be removed. - input_nodes = self._training_session.get_inputs() - input_node_names = [input_node.name for input_node in input_nodes] - - # Bind input tensors - for input, input_desc in zip(inputs, inputs_desc): - if input_desc.name in input_node_names: - device_index = _utils.get_device_index_from_input(input) - iobinding.bind_input( - input_desc.name, - input.device.type, - device_index, - _utils.dtype_torch_to_numpy(input.dtype), - list(input.size()), - input.data_ptr(), - ) - - # Bind output tensors - outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc) - result = {} - for output_desc in outputs_desc_resolved: - target_device = self.options.device.id - if self.options.mixed_precision.enabled and output_desc.name == self.model_desc.all_finite.name: - # Keep all finite flag on CPU to match backend implementation - # This prevents CPU -> GPU -> CPU copies between frontend and backend - target_device = "cpu" - # the self.options.device may be a device that pytorch does not recognize. - # in that case, we temporary prefer to leave the input/output on CPU and let ORT session - # to move the data between device and host. - # so output will be on the same device as input. 
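For readers following the IO-binding logic above, here is a minimal, illustrative sketch of the same zero-copy pattern written against the public onnxruntime.InferenceSession API, which remains available. The model path "model.onnx", the input/output names "x" and "y", and the (2, 3) float32 shapes are assumptions for illustration only, not part of this diff.

import numpy as np
import onnxruntime as ort
import torch

sess = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
x = torch.randn(2, 3, device="cuda").contiguous()   # input tensor already resident on the GPU
y = torch.empty(2, 3, device="cuda")                # pre-allocated output buffer on the same device

binding = sess.io_binding()
# Bind the torch buffers by raw pointer so ORT reads and writes them in place, with no host round trip.
binding.bind_input("x", x.device.type, x.device.index or 0, np.float32, list(x.shape), x.data_ptr())
binding.bind_output("y", y.device.type, y.device.index or 0, np.float32, list(y.shape), y.data_ptr())
sess.run_with_iobinding(binding)                    # results land directly in the pre-allocated tensor y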
- try: - torch.device(target_device) - except Exception: - # in this case, input/output must on CPU - assert input.device.type == "cpu" - target_device = "cpu" - - torch_tensor = torch.zeros( - output_desc.shape, - device=target_device, - dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype, - ) - iobinding.bind_output( - output_desc.name, - torch_tensor.device.type, - _utils.get_device_index(target_device), - _utils.dtype_torch_to_numpy(torch_tensor.dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - result[output_desc.name] = torch_tensor - - # Run a train/eval step - self._training_session.run_with_iobinding(iobinding, run_options) - return result - - def _update_onnx_model_initializers(self, state_tensors): - r"""Updates ONNX graph initializers with state_tensors's values - - Usually called to save or load an ONNX model. - - The tensors names of state_tensors are compared to all ONNX initializer tensors - and when the name matches, the ONNX graph is updated with the new value. - """ - assert isinstance(state_tensors, dict), "state_tensors must be a dict" - - new_weights = [] - replace_indices = [] - for i, w in enumerate(self._onnx_model.graph.initializer): - if w.name in state_tensors: - new_weights.append(onnx.numpy_helper.from_array(state_tensors[w.name], w.name)) - replace_indices.append(i) - replace_indices.sort(reverse=True) - for w_i in replace_indices: - del self._onnx_model.graph.initializer[w_i] - self._onnx_model.graph.initializer.extend(new_weights) - - def _extract_model_states(self, state_dict, pytorch_format): - """Extract model states from the training session and load into the state_dict""" - - model_states = self._training_session.get_model_state(include_mixed_precision_weights=False) - state_dict[_utils.state_dict_model_key()] = {} - - # extract trained model weights from the training session - for precision in model_states: - state_dict[_utils.state_dict_model_key()][precision] = {} - for model_state_key in model_states[precision]: - if pytorch_format: - state_dict[_utils.state_dict_model_key()][precision][model_state_key] = torch.from_numpy( - model_states[precision][model_state_key] - ) - else: - state_dict[_utils.state_dict_model_key()][precision][model_state_key] = model_states[precision][ - model_state_key - ] - - # extract untrained (frozen) model weights - for node in self._onnx_model.graph.initializer: - if ( - node.name not in state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] - and node.name in self.options.utils.frozen_weights - ): - if pytorch_format: - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][ - node.name - ] = torch.from_numpy(onnx.numpy_helper.to_array(node)) - else: - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][ - node.name - ] = onnx.numpy_helper.to_array(node) - - def _extract_trainer_options(self, state_dict): - """Extract relevant trainer configuration and load it into the state_dict""" - - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 
- - state_dict[_utils.state_dict_trainer_options_key()] = {} - state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = self.options.mixed_precision.enabled - state_dict[_utils.state_dict_trainer_options_key()][ - zero_stage - ] = self.options.distributed.deepspeed_zero_optimization.stage - state_dict[_utils.state_dict_trainer_options_key()][world_rank] = self.options.distributed.world_rank - state_dict[_utils.state_dict_trainer_options_key()][world_size] = self.options.distributed.world_size - state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = self.optim_config.name - state_dict[_utils.state_dict_trainer_options_key()][D_size] = self.options.distributed.data_parallel_size - state_dict[_utils.state_dict_trainer_options_key()][H_size] = self.options.distributed.horizontal_parallel_size - - def _extract_train_step_info(self, state_dict): - """Extract train step info settings and save it into the state_dict""" - - optimization_step = _utils.state_dict_train_step_info_optimization_step_key() - step = _utils.state_dict_train_step_info_step_key() - - state_dict[_utils.state_dict_train_step_info_key()] = {} - state_dict[_utils.state_dict_train_step_info_key()][optimization_step] = self._train_step_info.optimization_step - state_dict[_utils.state_dict_train_step_info_key()][step] = self._train_step_info.step - - def state_dict(self, pytorch_format=False): - """Returns a dictionary with model, train step info and optionally, optimizer states - - The returned dictionary contains the following information: - - Model and optimizer states - - Required ORTTrainerOptions settings - - Distributed training information, such as but not limited to ZeRO - - Train step info settings - - Structure of the returned dictionary: - - When `pytorch_format = False` - schema: - { - "model": - { - type: dict, - schema: - { - "full_precision": - { - type: dict, - schema: - { - model_weight_name: - { - type: array - } - } - } - } - }, - "optimizer": - { - type: dict, - schema: - { - model_weight_name: - { - type: dict, - schema: - { - "Moment_1": - { - type: array - }, - "Moment_2": - { - type: array - }, - "Update_Count": - { - type: array, - optional: True # present if optimizer is adam, absent otherwise - } - } - }, - "shared_optimizer_state": - { - type: dict, - optional: True, # present optimizer is shared, absent otherwise. 
- schema: - { - "step": - { - type: array, - } - } - } - } - }, - "trainer_options": - { - type: dict, - schema: - { - "mixed_precision": - { - type: bool - }, - "zero_stage": - { - type: int - }, - "world_rank": - { - type: int - }, - "world_size": - { - type: int - }, - "optimizer_name": - { - type: str - }, - "data_parallel_size": - { - type: int - }, - "horizontal_parallel_size": - { - type: int - } - } - }, - "partition_info": - { - type: dict, - optional: True, # present if states partitioned, else absent - schema: - { - model_weight_name: - { - type: dict, - schema: - { - "original_dim": - { - type: array - }, - "megatron_row_partition": - { - type: int - } - } - } - } - }, - "train_step_info": - { - type: dict, - schema: - { - "optimization_step": - { - type: int - }, - "step": - { - type: int - } - } - } - } - - When `pytorch_format = True` - schema: - { - model_weight_name: - { - type: tensor - } - } - - Args: - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema - - Returns: - A dictionary with `ORTTrainer` state - """ - if not self._training_session: - warnings.warn( - "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict().", - UserWarning, - ) - return self._load_state_dict.args[0] if self._load_state_dict else {} - - state_dict = {} - - # load training session model states into the state_dict - self._extract_model_states(state_dict, pytorch_format) - if pytorch_format: - if self.options.distributed.deepspeed_zero_optimization.stage > 0: - warnings.warn("Incomplete state_dict: ZeRO enabled", UserWarning) - if self.options.distributed.horizontal_parallel_size > 1: - warnings.warn("Incomplete state_dict: Megatron enabled", UserWarning) - # if pytorch_format is true, return a flat dictionary with only model states - # which is compatible with a PyTorch model - return state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] - - # load training session optimizer states into the state_dict - state_dict[_utils.state_dict_optimizer_key()] = self._training_session.get_optimizer_state() - - # extract the relevant training configuration from the trainer and load them into the state_dict - self._extract_trainer_options(state_dict) - - # Extract train step info settings and load it into the state_dict - self._extract_train_step_info(state_dict) - - # add partition information in case of a distributed run - if ( - self.options.distributed.deepspeed_zero_optimization.stage > 0 - or self.options.distributed.horizontal_parallel_size > 1 - ): - state_dict[_utils.state_dict_partition_info_key()] = self._training_session.get_partition_info_map() - - return state_dict - - def _load_model_states(self, state_dict, strict): - """Load the model states onto the onnx model graph""" - - if _utils.state_dict_model_key() not in state_dict: - return - - # collect all initializer names from the current onnx graph - assert self._onnx_model, "ONNX model graph is not exported" - initializer_names = {node.name for node in self._onnx_model.graph.initializer} - - # loaded_initializers dict will be loaded with all the model states from the state dictionary - # that are found in the initializer_names dictionary - loaded_initializers = {} - - # copy over model states from the input state dict onto the onnx model - for precision, precision_states in state_dict[_utils.state_dict_model_key()].items(): - for state_key, state_value in precision_states.items(): - if state_key 
in initializer_names: - loaded_initializers[state_key] = state_value - elif strict: - raise RuntimeError(f"Unexpected key: {state_key} in state_dict[model][{precision}]") - - # update onnx model from loaded initializers - self._update_onnx_model_initializers(loaded_initializers) - - def _load_optimizer_states(self, current_state_dict, state_dict): - """Load the optimizer states onto the training session state dictionary""" - - def _check_optimizer_mismatch(state_dict): - """Assert that the loaded optimizer has the same config as the current training session config""" - - # the state_dict optimizer_name can be a byte string (if coming from checkpoint file) - # or can be a regular string (coming from user) - optimizer_name = state_dict[_utils.state_dict_trainer_options_key()][ - _utils.state_dict_trainer_options_optimizer_name_key() - ] - - # optimizer_name can be either a regular string or a byte string. - # if it is a byte string, convert to regular string using decode() - # if it is a regular string, do nothing to it - try: # noqa: SIM105 - optimizer_name = optimizer_name.decode() - except AttributeError: - pass - assert self.optim_config.name == optimizer_name, "Optimizer mismatch: expected {}, got {}".format( - self.optim_config.name, optimizer_name - ) - - if _utils.state_dict_optimizer_key() not in state_dict: - return - - # check optimizer config names are the same for current session and the sessino being loaded - _check_optimizer_mismatch(state_dict) - - # create an entry for the optimizer in the training session state dictionary - if _utils.state_dict_optimizer_key() not in current_state_dict: - current_state_dict[_utils.state_dict_optimizer_key()] = {} - - # copy over optimizer states from the input state dict onto the training session state dict - for model_state_key, optimizer_dict in state_dict[_utils.state_dict_optimizer_key()].items(): - if model_state_key not in current_state_dict[_utils.state_dict_optimizer_key()]: - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key] = {} - for optimizer_state_key, optimizer_state_value in optimizer_dict.items(): - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key][ - optimizer_state_key - ] = optimizer_state_value - - def _load_state_dict_impl(self, state_dict, strict=True): - """Load the state dictionary onto the onnx model and on the training session graph""" - - # clear the callable partial - self._load_state_dict = None - - def _mismatch_keys(keys1, keys2, in_error_str, allow_unexpected=False): - """Find out the missing and the unexpected keys in two dictionaries - - Throws a runtime error if missing or unexpected keys are found - - Keys in keys1 not in keys2 will be marked as missing - - Keys in keys2 not in keys1 will be marked as unexpected - """ - keys1 = set(keys1) - keys2 = set(keys2) - missing_keys = list(keys1 - keys2) - unexpected_keys = list(keys2 - keys1) - if len(missing_keys) > 0: - raise RuntimeError(f"Missing keys: {missing_keys} in {in_error_str}") - if len(unexpected_keys) > 0 and not allow_unexpected: - raise RuntimeError(f"Unexpected keys: {unexpected_keys} in {in_error_str}") - - def _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is any mismatch in the model sub state dictionary between the two state_dicts""" - - # check unxexpected and missing precision keys in the model state_dict compared to the training - # session model state_dict - _mismatch_keys( - current_state_dict[_utils.state_dict_model_key()], - 
state_dict[_utils.state_dict_model_key()], - "state_dict[model]", - allow_unexpected, - ) - - # check for model state key mismatch - for precision_key in current_state_dict[_utils.state_dict_model_key()]: - _mismatch_keys( - current_state_dict[_utils.state_dict_model_key()][precision_key], - state_dict[_utils.state_dict_model_key()][precision_key], - f"state_dict[model][{precision_key}]", - allow_unexpected, - ) - - def _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is any mismatch in the optimizer sub state dictionary between the two state_dicts""" - - # check for model state key mismatch for the optimizer state_dict - _mismatch_keys( - current_state_dict[_utils.state_dict_optimizer_key()], - state_dict[_utils.state_dict_optimizer_key()], - "state_dict[optimizer]", - allow_unexpected, - ) - - # check for optimizer state keys mismatch - for model_state_key in current_state_dict[_utils.state_dict_optimizer_key()]: - _mismatch_keys( - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key], - state_dict[_utils.state_dict_optimizer_key()][model_state_key], - f"state_dict[optimizer][{model_state_key}]", - allow_unexpected, - ) - - def _check_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is a mismatch in the keys (model and optimizer) in the two state_dicts""" - - # check presence of 'model' in the input state_dict - if _utils.state_dict_model_key() in state_dict: - _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected) - else: - warnings.warn("Missing key: model in state_dict", UserWarning) - # check presence of 'optimizer' in the input state_dict - if _utils.state_dict_optimizer_key() in state_dict: - _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected) - else: - warnings.warn("Missing key: optimizer in state_dict", UserWarning) - - # extract state dict from the current training session. this is to persist the states between - # two training sessions. 
- # for example, if user provided only the model states, the optimizer states from the current - # training session must be persisted - current_state_dict = {} - if self._training_session: - current_state_dict = self.state_dict() - if strict: - # for Zero enabled, the current trainer might not have the complete state, and we must allow - # extra keys to be present in the state dict - allow_unexpected = self.options.distributed.deepspeed_zero_optimization.stage > 0 - _check_key_mismatch(current_state_dict, state_dict, allow_unexpected) - - # load the model states from the input state dictionary into the onnx graph - self._load_model_states(state_dict, strict) - - # load the optimizer states from the input state dictionary into the training session states - # dictionary - self._load_optimizer_states(current_state_dict, state_dict) - - return ( - current_state_dict[_utils.state_dict_optimizer_key()] - if _utils.state_dict_optimizer_key() in current_state_dict - else {} - ) - - def _load_train_step_info(self, state_dict): - """Load the train step info settings from state dict""" - - if _utils.state_dict_train_step_info_key() not in state_dict: - warnings.warn("Missing key: train_step_info in state_dict", UserWarning) - return - - optimization_step = _utils.state_dict_train_step_info_optimization_step_key() - step = _utils.state_dict_train_step_info_step_key() - - self._train_step_info.optimization_step = state_dict[_utils.state_dict_train_step_info_key()][optimization_step] - self._train_step_info.step = state_dict[_utils.state_dict_train_step_info_key()][step] - - def load_state_dict(self, state_dict, strict=True): - """Loads state_dict containing model/optimizer states into ORTTrainer - - The state_dict dictionary may contain the following information: - - Model and optimizer states - - Required ORTTrainerOptions settings - - Distributed training information, such as but not limited to ZeRO - - Args: - state_dict: state dictionary containing both model and optimizer states. The structure of this dictionary - should be the same as the one that is returned by ORTTrainer.state_dict for the case when pytorch_format=False - strict: boolean flag to strictly enforce that the input state_dict keys match the keys from ORTTrainer.state_dict - """ - - # if onnx graph has not been initialized, loading of states will be put on hold. - # a copy of the state_dict and other arguments to the function will be stored until the onnx graph has - # been initialized. Once the graph is initialized, the desired states will be loaded onto the grpah - if not self._training_session: - self._load_state_dict = partial(self._load_state_dict_impl, state_dict, strict=strict) - return - - # load the train step info settings - self._load_train_step_info(state_dict) - - # load states onto the frontend onnx graph - optimizer_state_dict = self._load_state_dict_impl(state_dict, strict=strict) - - # create a new training session after loading initializer states onto the onnx graph - # pass the populated states to the training session to populate the backend graph - self._init_session( - optimizer_state_dict, - session_options=self.options.session_options, - provider_options=self.options._validated_opts["provider_options"], - ) - - def save_checkpoint(self, path, user_dict={}, include_optimizer_states=True): # noqa: B006 - """Persists ORTTrainer state dictionary on disk along with user_dict. - - Saves the state_dict along with the user_dict to a file specified by path. 
- - Args: - path: string representation to a file path or a python file-like object. - if file already exists at path, an exception is raised. - user_dict: custom data to be saved along with the state_dict. This data will be returned - to the user when load_checkpoint is called. - include_optimizer_states: boolean flag indicating whether or not to persist the optimizer states. - on load_checkpoint, only model states will be loaded if include_optimizer_states==True - """ - - # extract state_dict to be saved in the checkpoint - state_dict = self.state_dict() - - # if user_dict is provided, serialize to bytes and convert to hex string. - # this helps in loading the types as they are given by the user since hdf5 - # converts to numpy types otherwise - if bool(user_dict): - state_dict[_utils.state_dict_user_dict_key()] = _checkpoint_storage.to_serialized_hex(user_dict) - - # if include_optimizer_states is False, only save the model states in the checkpoint file - if not include_optimizer_states: - if _utils.state_dict_optimizer_key() in state_dict: - del state_dict[_utils.state_dict_optimizer_key()] - - _checkpoint_storage.save(state_dict, path) - - def _aggregation_required(self, loaded_trainer_options): - """Checks if aggregation is required for the loading the state_dict into the ORTTrainer""" - - # To load states in the backend, aggregation is required for every ZeRO - # or Megatron checkpoint - return ( - loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 - or loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1 - ) - - def load_checkpoint(self, *paths, strict=True): - """Loads the saved checkpoint state dictionary into the ORTTrainer - - Reads the saved checkpoint files specified by paths from disk and loads the state dictionary - onto the ORTTrainer. - Aggregates the checkpoint files if aggregation is required. 
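Since the save_checkpoint/load_checkpoint pair above is being deleted, a short usage sketch shows what callers of the old API looked like. The trainer object, file name, and user metadata below are illustrative assumptions only.

# `trainer` is assumed to be an already-constructed ORTTrainer that has run at least one train_step.
trainer.save_checkpoint("model.ortcp", user_dict={"epoch": 5}, include_optimizer_states=True)

# Later, or in a different process: restore model/optimizer states and get the user metadata back.
restored_user_dict = trainer.load_checkpoint("model.ortcp", strict=True)
assert restored_user_dict["epoch"] == 5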
- - Args: - paths: one or more files represented as strings where the checkpoint is saved - strict: boolean flag to strictly enforce that the saved checkpoint state_dict - keys match the keys from ORTTrainer.state_dict - Returns: - dictionary that the user had saved when calling save_checkpoint - """ - state_dict = {} - - # check if aggregation is required - loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key()) - if self._aggregation_required(loaded_trainer_options): - # if aggregation is required, aggregation logic must be run on the saved checkpoints - state_dict = checkpoint.aggregate_checkpoints(paths, pytorch_format=False) - else: - # if aggregation is not required, there must only be a single file that needs to be loaded - assert len(paths) == 1, f"Expected number of files to load: 1, got {len(paths)}" - state_dict = _checkpoint_storage.load(paths[0]) - - # extract user dict from the saved checkpoint - user_dict = {} - if _utils.state_dict_user_dict_key() in state_dict: - user_dict = _checkpoint_storage.from_serialized_hex(state_dict[_utils.state_dict_user_dict_key()]) - del state_dict[_utils.state_dict_user_dict_key()] - - self.load_state_dict(state_dict, strict=strict) - - return user_dict diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py deleted file mode 100644 index c63ac6f82c87f..0000000000000 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ /dev/null @@ -1,692 +0,0 @@ -import cerberus - -import onnxruntime as ort -from onnxruntime.capi._pybind_state import PropagateCastOpsStrategy - -from .amp import loss_scaler -from .optim import lr_scheduler - - -class ORTTrainerOptions: - r"""Settings used by ONNX Runtime training backend - - The parameters are hierarchically organized to facilitate configuration through semantic groups - that encompasses features, such as distributed training, etc. - - Input validation is performed on the input dict during instantiation to ensure - that supported parameters and values are passed in. Invalid input results - in :py:obj:`ValueError` exception with details on it. - - Args: - options (dict): contains all training options - _validate (bool, default is True): for internal use only - - Supported schema for kwargs: - - .. 
code-block:: python - - schema = { - 'batch' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'gradient_accumulation_steps' : { - 'type' : 'integer', - 'min' : 1, - 'default' : 1 - } - }, - }, - 'device' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'id' : { - 'type' : 'string', - 'default' : 'cuda' - }, - 'mem_limit' : { - 'type' : 'integer', - 'min' : 0, - 'default' : 0 - } - } - }, - 'distributed': { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'world_rank': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'world_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'local_rank': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'data_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'horizontal_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'pipeline_parallel' : { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'pipeline_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'num_pipeline_micro_batches': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'pipeline_cut_info_string': { - 'type': 'string', - 'default': '' - }, - 'sliced_schema': { - 'type': 'dict', - 'default': {}, - 'keysrules': {'type': 'string'}, - 'valuesrules': { - 'type': 'list', - 'schema': {'type': 'integer'} - } - }, - 'sliced_axes': { - 'type': 'dict', - 'default': {}, - 'keysrules': {'type': 'string'}, - 'valuesrules': {'type': 'integer'} - }, - 'sliced_tensor_names': { - 'type': 'list', - 'schema': {'type': 'string'}, - 'default': [] - } - } - }, - 'allreduce_post_accumulation': { - 'type': 'boolean', - 'default': False - }, - 'deepspeed_zero_optimization': { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'stage': { - 'type': 'integer', - 'min': 0, - 'max': 1, - 'default': 0 - }, - } - }, - 'enable_adasum': { - 'type': 'boolean', - 'default': False - } - } - }, - 'lr_scheduler' : { - 'type' : 'optim.lr_scheduler', - 'nullable' : True, - 'default' : None - }, - 'mixed_precision' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'enabled' : { - 'type' : 'boolean', - 'default' : False - }, - 'loss_scaler' : { - 'type' : 'amp.loss_scaler', - 'nullable' : True, - 'default' : None - } - } - }, - 'graph_transformer': { - 'type': 'dict', - 'required': False, - 'default': {}, - 'schema': { - 'attn_dropout_recompute': { - 'type': 'boolean', - 'default': False - }, - 'gelu_recompute': { - 'type': 'boolean', - 'default': False - }, - 'transformer_layer_recompute': { - 'type': 'boolean', - 'default': False - }, - 'number_recompute_layers': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'propagate_cast_ops_config': { - 'type': 'dict', - 'required': False, - 'default': {}, - 'schema': { - 'propagate_cast_ops_strategy': { - 'type': 'onnxruntime.training.PropagateCastOpsStrategy', - 'default': PropagateCastOpsStrategy.FLOOD_FILL - }, - 'propagate_cast_ops_level': { - 'type': 'integer', - 'default': 1 - }, - 'propagate_cast_ops_allow': { - 'type': 'list', - 'schema': {'type': 'string'}, - 'default': [] - } - } - } - } - }, - 'utils' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'frozen_weights' : { - 'type' : 'list', - 'default' : [] - }, - 'grad_norm_clip' : { - 'type' : 'boolean', - 'default' : True - }, - 'memory_efficient_gradient' : { - 'type' : 'boolean', - 'default' : False - }, - 'run_symbolic_shape_infer' : { - 
'type' : 'boolean', - 'default' : False - } - } - }, - 'debug' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'deterministic_compute' : { - 'type' : 'boolean', - 'default' : False - }, - 'check_model_export' : { - 'type' : 'boolean', - 'default' : False - }, - 'graph_save_paths' : { - 'type' : 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'model_after_graph_transforms_path': { - 'type': 'string', - 'default': '' - }, - 'model_with_gradient_graph_path':{ - 'type': 'string', - 'default': '' - }, - 'model_with_training_graph_path': { - 'type': 'string', - 'default': '' - }, - 'model_with_training_graph_after_optimization_path': { - 'type': 'string', - 'default': '' - }, - } - }, - } - }, - '_internal_use' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'enable_internal_postprocess' : { - 'type' : 'boolean', - 'default' : True - }, - 'extra_postprocess' : { - 'type' : 'callable', - 'nullable' : True, - 'default' : None - }, - 'onnx_opset_version': { - 'type': 'integer', - 'min' : 12, - 'max' :14, - 'default': 14 - }, - 'enable_onnx_contrib_ops' : { - 'type' : 'boolean', - 'default' : True - } - } - }, - 'provider_options':{ - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': {} - }, - 'session_options': { - 'type': 'SessionOptions', - 'nullable': True, - 'default': None - }, - } - - Keyword arguments: - batch (dict): - batch related settings - batch.gradient_accumulation_steps (int, default is 1): - number of steps to accumulate before do collective gradient reduction - device (dict): - compute device related settings - device.id (string, default is 'cuda'): - device to run training - device.mem_limit (int): - maximum memory size (in bytes) used by device.id - distributed (dict): - distributed training options. - distributed.world_rank (int, default is 0): - rank ID used for data/horizontal parallelism - distributed.world_size (int, default is 1): - number of ranks participating in parallelism - distributed.data_parallel_size (int, default is 1): - number of ranks participating in data parallelism - distributed.horizontal_parallel_size (int, default is 1): - number of ranks participating in horizontal parallelism - distributed.pipeline_parallel (dict): - Options which are only useful to pipeline parallel. - distributed.pipeline_parallel.pipeline_parallel_size (int, default is 1): - number of ranks participating in pipeline parallelism - distributed.pipeline_parallel.num_pipeline_micro_batches (int, default is 1): - number of micro-batches. We divide input batch into micro-batches and run the graph. - distributed.pipeline_parallel.pipeline_cut_info_string (string, default is ''): - string of cutting ids for pipeline partition. - distributed.allreduce_post_accumulation (bool, default is False): - True enables overlap of AllReduce with computation, while False, - postpone AllReduce until all gradients are ready - distributed.deepspeed_zero_optimization: - DeepSpeed ZeRO options. - distributed.deepspeed_zero_optimization.stage (int, default is 0): - select which stage of DeepSpeed ZeRO to use. Stage 0 means disabled. 
- distributed.enable_adasum (bool, default is False): - enable `Adasum `_ - algorithm for AllReduce - lr_scheduler (optim._LRScheduler, default is None): - specifies learning rate scheduler - mixed_precision (dict): - mixed precision training options - mixed_precision.enabled (bool, default is False): - enable mixed precision (fp16) - mixed_precision.loss_scaler (amp.LossScaler, default is None): - specifies a loss scaler to be used for fp16. If not specified, - :py:class:`.DynamicLossScaler` is used with default values. - Users can also instantiate :py:class:`.DynamicLossScaler` and - override its parameters. Lastly, a completely new implementation - can be specified by extending :py:class:`.LossScaler` class from scratch - graph_transformer (dict): - graph transformer related configurations - graph_transformer.attn_dropout_recompute(bool, default False) - graph_transformer.gelu_recompute(bool, default False) - graph_transformer.transformer_layer_recompute(bool, default False) - graph_transformer.number_recompute_layers(bool, default False) - graph_transformer.propagate_cast_ops_config (dict): - graph_transformer.propagate_cast_ops_config.strategy(PropagateCastOpsStrategy, default FLOOD_FILL) - Specify the choice of the cast propagation optimization strategy, either, NONE, INSERT_AND_REDUCE or FLOOD_FILL. - NONE strategy does not perform any cast propagation transformation on the graph, although other optimizations - locally change cast operations, for example, in order to fuse Transpose and MatMul nodes, the TransposeMatMulFunsion optimization could - interchange Transpose and Cast if the Cast node exists between Transpose and MatMul. - INSERT_AND_REDUCE strategy inserts and reduces cast operations around the nodes with allowed opcodes. - FLOOD_FILL strategy expands float16 regions in the graph using the allowed opcodes, and unlike - INSERT_AND_REDUCE does not touch opcodes outside expanded float16 region. - graph_transformer.propagate_cast_ops_config.level(integer, default 1) - Optimize by moving Cast operations if propagate_cast_ops_level is non-negative. - Use predetermined list of opcodes considered safe to move before/after cast operation - if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise. - graph_transformer.propagate_cast_ops_config.allow(list of str, []) - List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero. - attn_dropout_recompute (bool, default is False): - enable recomputing attention dropout to save memory - gelu_recompute (bool, default is False): - enable recomputing Gelu activation output to save memory - transformer_layer_recompute (bool, default is False): - enable recomputing transformer layerwise to save memory - number_recompute_layers (int, default is 0) - number of layers to apply transformer_layer_recompute, by default system will - apply recompute to all the layers, except for the last one - utils (dict): - miscellaneous options - utils.frozen_weights (list of str, []): - list of model parameter names to skip training (weights don't change) - utils.grad_norm_clip (bool, default is True): - enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer' - utils.memory_efficient_gradient (bool, default is False): - enables use of memory aware gradient builder. 
- utils.run_symbolic_shape_infer (bool, default is False): - runs symbolic shape inference on the model - debug (dict): - debug options - debug.deterministic_compute (bool, default is False) - forces compute to be deterministic accross runs - debug.check_model_export (bool, default is False) - compares PyTorch model outputs with ONNX model outputs in inference before the first - train step to ensure successful model export - debug.graph_save_paths (dict): - paths used for dumping ONNX graphs for debugging purposes - debug.graph_save_paths.model_after_graph_transforms_path (str, default is "") - path to export the ONNX graph after training-related graph transforms have been applied. - No output when it is empty. - debug.graph_save_paths.model_with_gradient_graph_path (str, default is "") - path to export the ONNX graph with the gradient graph added. No output when it is empty. - debug.graph_save_paths.model_with_training_graph_path (str, default is "") - path to export the training ONNX graph with forward, gradient and optimizer nodes. - No output when it is empty. - debug.graph_save_paths.model_with_training_graph_after_optimization_path (str, default is "") - outputs the optimized training graph to the path if nonempty. - _internal_use (dict): - internal options, possibly undocumented, that might be removed without notice - _internal_use.enable_internal_postprocess (bool, default is True): - enable internal internal post processing of the ONNX model - _internal_use.extra_postprocess (callable, default is None) - a functor to postprocess the ONNX model and return a new ONNX model. - It does not override :py:attr:`._internal_use.enable_internal_postprocess`, but complement it - _internal_use.onnx_opset_version (int, default is 14): - ONNX opset version used during model exporting. - _internal_use.enable_onnx_contrib_ops (bool, default is True) - enable PyTorch to export nodes as contrib ops in ONNX. - This flag may be removed anytime in the future. - session_options (onnxruntime.SessionOptions): - The SessionOptions instance that TrainingSession will use. - provider_options (dict): - The provider_options for customized execution providers. it is dict map from EP name to - a key-value pairs, like {'EP1' : {'key1' : 'val1'}, ....} - - Example: - .. 
code-block:: python - - opts = ORTTrainerOptions({ - 'batch' : { - 'gradient_accumulation_steps' : 128 - }, - 'device' : { - 'id' : 'cuda:0', - 'mem_limit' : 2*1024*1024*1024, - }, - 'lr_scheduler' : optim.lr_scheduler.LinearWarmupLRScheduler(), - 'mixed_precision' : { - 'enabled': True, - 'loss_scaler': amp.LossScaler(loss_scale=float(1 << 16)) - } - }) - fp16_enabled = opts.mixed_precision.enabled - """ - - def __init__(self, options={}): # noqa: B006 - # Keep a copy of original input for debug - self._original_opts = dict(options) - - # Used for logging purposes - self._main_class_name = self.__class__.__name__ - - # Validates user input - self._validated_opts = dict(self._original_opts) - validator = ORTTrainerOptionsValidator(_ORTTRAINER_OPTIONS_SCHEMA) - self._validated_opts = validator.validated(self._validated_opts) - if self._validated_opts is None: - raise ValueError(f"Invalid options: {validator.errors}") - - # Convert dict in object - for k, v in self._validated_opts.items(): - setattr(self, k, self._wrap(v)) - - def __repr__(self): - return "{%s}" % str( - ", ".join( - f"'{k}': {v!r}" - for (k, v) in self.__dict__.items() - if k not in ["_original_opts", "_validated_opts", "_main_class_name"] - ) - ) - - def _wrap(self, v): - if isinstance(v, (tuple, list, set, frozenset)): - return type(v)([self._wrap(i) for i in v]) - else: - return _ORTTrainerOptionsInternal(self._main_class_name, v) if isinstance(v, dict) else v - - -class _ORTTrainerOptionsInternal(ORTTrainerOptions): - r"""Internal class used by ONNX Runtime training backend for input validation - - NOTE: Users MUST NOT use this class in any way! - """ - - def __init__(self, main_class_name, options): - # Used for logging purposes - self._main_class_name = main_class_name - # We don't call super().__init__(options) here but still called it "_validated_opts" - # instead of "_original_opts" because it has been validated in the top-level - # ORTTrainerOptions's constructor. 
- self._validated_opts = dict(options) - # Convert dict in object - for k, v in dict(options).items(): - setattr(self, k, self._wrap(v)) - - -class ORTTrainerOptionsValidator(cerberus.Validator): - _LR_SCHEDULER = cerberus.TypeDefinition("lr_scheduler", (lr_scheduler._LRScheduler,), ()) - _LOSS_SCALER = cerberus.TypeDefinition("loss_scaler", (loss_scaler.LossScaler,), ()) - - _SESSION_OPTIONS = cerberus.TypeDefinition("session_options", (ort.SessionOptions,), ()) - - _PROPAGATE_CAST_OPS_STRATEGY = cerberus.TypeDefinition( - "propagate_cast_ops_strategy", (PropagateCastOpsStrategy,), () - ) - - types_mapping = cerberus.Validator.types_mapping.copy() - types_mapping["lr_scheduler"] = _LR_SCHEDULER - types_mapping["loss_scaler"] = _LOSS_SCALER - types_mapping["session_options"] = _SESSION_OPTIONS - types_mapping["propagate_cast_ops_strategy"] = _PROPAGATE_CAST_OPS_STRATEGY - - -def _check_is_callable(field, value, error): - result = False - try: - # Python 3 - result = value is None or callable(value) - except Exception: - # Python 3 but < 3.2 - if hasattr(value, "__call__"): # noqa: B004 - result = True - if not result: - error(field, "Must be callable or None") - - -_ORTTRAINER_OPTIONS_SCHEMA = { - "batch": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": {"gradient_accumulation_steps": {"type": "integer", "min": 1, "default": 1}}, - }, - "device": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "id": {"type": "string", "default": "cuda"}, - "mem_limit": {"type": "integer", "min": 0, "default": 0}, - }, - }, - "distributed": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "world_rank": {"type": "integer", "min": 0, "default": 0}, - "world_size": {"type": "integer", "min": 1, "default": 1}, - "local_rank": {"type": "integer", "min": 0, "default": 0}, - "data_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "horizontal_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "pipeline_parallel": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "pipeline_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "num_pipeline_micro_batches": {"type": "integer", "min": 1, "default": 1}, - "pipeline_cut_info_string": {"type": "string", "default": ""}, - "sliced_schema": { - "type": "dict", - "default_setter": lambda _: {}, - "keysrules": {"type": "string"}, - "valuesrules": {"type": "list", "schema": {"type": "integer"}}, - }, - "sliced_axes": { - "type": "dict", - "default_setter": lambda _: {}, - "keysrules": {"type": "string"}, - "valuesrules": {"type": "integer"}, - }, - "sliced_tensor_names": {"type": "list", "schema": {"type": "string"}, "default": []}, - }, - }, - "allreduce_post_accumulation": {"type": "boolean", "default": False}, - "deepspeed_zero_optimization": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "stage": {"type": "integer", "min": 0, "max": 1, "default": 0}, - }, - }, - "enable_adasum": {"type": "boolean", "default": False}, - }, - }, - "lr_scheduler": {"type": "lr_scheduler", "nullable": True, "default": None}, - "mixed_precision": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "enabled": {"type": "boolean", "default": False}, - "loss_scaler": {"type": "loss_scaler", "nullable": True, "default": None}, - }, - }, - "graph_transformer": { - "type": "dict", - "default_setter": lambda _: {}, - 
"required": False, - "schema": { - "attn_dropout_recompute": {"type": "boolean", "default": False}, - "gelu_recompute": {"type": "boolean", "default": False}, - "transformer_layer_recompute": {"type": "boolean", "default": False}, - "number_recompute_layers": {"type": "integer", "min": 0, "default": 0}, - "propagate_cast_ops_config": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "strategy": { - "type": "propagate_cast_ops_strategy", - "nullable": True, - "default": PropagateCastOpsStrategy.FLOOD_FILL, - }, - "level": {"type": "integer", "min": -1, "default": 1}, - "allow": {"type": "list", "schema": {"type": "string"}, "default": []}, - }, - }, - }, - }, - "utils": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "frozen_weights": {"type": "list", "default": []}, - "grad_norm_clip": {"type": "boolean", "default": True}, - "memory_efficient_gradient": {"type": "boolean", "default": False}, - "run_symbolic_shape_infer": {"type": "boolean", "default": False}, - }, - }, - "debug": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "deterministic_compute": {"type": "boolean", "default": False}, - "check_model_export": {"type": "boolean", "default": False}, - "graph_save_paths": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "model_after_graph_transforms_path": {"type": "string", "default": ""}, - "model_with_gradient_graph_path": {"type": "string", "default": ""}, - "model_with_training_graph_path": {"type": "string", "default": ""}, - "model_with_training_graph_after_optimization_path": {"type": "string", "default": ""}, - }, - }, - }, - }, - "_internal_use": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "enable_internal_postprocess": {"type": "boolean", "default": True}, - "extra_postprocess": {"check_with": _check_is_callable, "nullable": True, "default": None}, - "onnx_opset_version": {"type": "integer", "min": 12, "max": 14, "default": 14}, - "enable_onnx_contrib_ops": {"type": "boolean", "default": True}, - }, - }, - "provider_options": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "allow_unknown": True, - "schema": {}, - }, - "session_options": {"type": "session_options", "nullable": True, "default": None}, -} diff --git a/orttraining/orttraining/python/training/postprocess.py b/orttraining/orttraining/python/training/postprocess.py deleted file mode 100644 index 6c2adb6af7978..0000000000000 --- a/orttraining/orttraining/python/training/postprocess.py +++ /dev/null @@ -1,478 +0,0 @@ -import os.path # noqa: F401 -import struct -import sys # noqa: F401 - -import numpy as np # noqa: F401 -import onnx -from onnx import * # noqa: F403 -from onnx import helper, numpy_helper # noqa: F401 - - -def run_postprocess(model): - # this post pass is not required for pytorch >= 1.5 - # where add_node_name in torch.onnx.export is default to True - model = add_name(model) - - # this post pass is not required for pytorch > 1.6 - model = fuse_softmaxNLL_to_softmaxCE(model) - - model = fix_expand_shape(model) - model = fix_expand_shape_pt_1_5(model) - return model - - -def find_input_node(model, arg): - result = [] - for node in model.graph.node: - for output in node.output: - if output == arg: - result.append(node) - return result[0] if len(result) == 1 else None - - -def find_output_node(model, arg): - result = [] - for node in model.graph.node: - for input in 
node.input: - if input == arg: - result.append(node) - return result[0] if len(result) == 1 else result - - -def add_name(model): - i = 0 - for node in model.graph.node: - node.name = "%s_%d" % (node.op_type, i) - i += 1 - return model - - -# Expand Shape PostProcess - - -def fix_expand_shape(model): - expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"] - model_inputs_names = [i.name for i in model.graph.input] - - for expand_node in expand_nodes: - shape = find_input_node(model, expand_node.input[1]) - if shape.op_type == "Shape": - # an expand subgraph - # Input Input2 - # | | - # | Shape - # | | - # |__ __| - # | | - # Expand - # | - # output - # - # Only if Input2 is one of the model inputs, assign Input2's shape to output of expand. - shape_input_name = shape.input[0] - if shape_input_name in model_inputs_names: - index = model_inputs_names.index(shape_input_name) - expand_out = model.graph.value_info.add() - expand_out.name = expand_node.output[0] - expand_out.type.CopyFrom(model.graph.input[index].type) - return model - - -def fix_expand_shape_pt_1_5(model): - # expand subgraph - # Constant - # + - # ConstantOfShape - # | + | - # | + | - # (Reshape subgraph) Mul | - # |___ _________| | - # + | | | - # + Equal | - # +++++|++++++++++++++|++ - # |____________ | + - # | | + - # (subgraph) Where - # | | - # |_____ ___________| - # | | - # Expand - # | - # output - # - # where the Reshape subgraph is - # - # Input - # | | - # | |___________________ - # | | - # Shape Constant Shape Constant - # | ______| | ______| - # | | | | - # Gather Gather - # | | - # Unsqueeze Unsqueeze - # | | - # | ..Number of dims.. | - # | _________________| - # |...| - # Concat Constant - # | | - # |______ __________________| - # | | - # Reshape - # | - # output - # - # This pass will copy Input's shape to the output of Expand. 
- expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"] - model_inputs_names = [i.name for i in model.graph.input] - - for expand_node in expand_nodes: - n_where = find_input_node(model, expand_node.input[1]) - if n_where.op_type != "Where": - continue - - n_equal = find_input_node(model, n_where.input[0]) - n_cos = find_input_node(model, n_where.input[1]) - n_reshape = find_input_node(model, n_where.input[2]) - - if n_equal.op_type != "Equal" or n_cos.op_type != "ConstantOfShape" or n_reshape.op_type != "Reshape": - continue - - n_reshape_e = find_input_node(model, n_equal.input[0]) - n_mul = find_input_node(model, n_equal.input[1]) - if n_reshape_e != n_reshape or n_mul.op_type != "Mul": - continue - - n_cos_m = find_input_node(model, n_mul.input[0]) - n_constant = find_input_node(model, n_mul.input[1]) - if n_cos_m != n_cos or n_constant.op_type != "Constant": - continue - - n_concat = find_input_node(model, n_reshape.input[0]) - n_constant_r = find_input_node(model, n_reshape.input[1]) - if n_concat.op_type != "Concat" or n_constant_r.op_type != "Constant": - continue - - n_input_candidates = [] - for concat_in in n_concat.input: - n_unsqueeze = find_input_node(model, concat_in) - if n_unsqueeze.op_type != "Unsqueeze": - break - n_gather = find_input_node(model, n_unsqueeze.input[0]) - if n_gather.op_type != "Gather": - break - n_shape = find_input_node(model, n_gather.input[0]) - n_constant_g = find_input_node(model, n_gather.input[1]) - if n_shape.op_type != "Shape" or n_constant_g.op_type != "Constant": - break - n_input = n_shape.input[0] - if n_input not in model_inputs_names: - break - n_input_candidates.append(n_input) - - if not n_input_candidates or not all(elem == n_input_candidates[0] for elem in n_input_candidates): - continue - - index = model_inputs_names.index(n_input_candidates[0]) - expand_out = model.graph.value_info.add() - expand_out.name = expand_node.output[0] - expand_out.type.CopyFrom(model.graph.input[index].type) - return model - - -# LayerNorm PostProcess - - -def find_nodes(graph, op_type): - nodes = [] - for node in graph.node: - if node.op_type == op_type: - nodes.append(node) - return nodes - - -def is_type(node, op_type): - if node is None or isinstance(node, list): - return False - return node.op_type == op_type - - -def add_const(model, name, output, t_value=None, f_value=None): - const_node = model.graph.node.add() - const_node.op_type = "Constant" - const_node.name = name - const_node.output.extend([output]) - attr = const_node.attribute.add() - attr.name = "value" - if t_value is not None: - attr.type = 4 - attr.t.CopyFrom(t_value) - else: - attr.type = 1 - attr.f = f_value - return const_node - - -def layer_norm_transform(model): - # DEPRECATED: This pass is no longer needed as the transform is handled at the backend. 
- # Converting below subgraph - # - # input - # | - # ReduceMean - # | - # Sub Constant - # _||_____ | - # | | | - # | | | - # | (optional) Cast (optional) Cast - # | | | - # | | ____________________| - # | | | - # | Pow - # | | - # | ReduceMean - # | | - # | Add - # | | - # |__ __Sqrt - # | | - # Div (weight) - # | | - # | _____| - # | | - # Mul (bias) - # | | - # | _____| - # | | - # Add - # | - # output - # - # to the below subgraph - # - # input (weight) (bias) - # | | | - # | _______| | - # | | ________________| - # | | | - # LayerNormalization - # | - # output - graph = model.graph - - nodes_ReduceMean = find_nodes(graph, "ReduceMean") # noqa: N806 - - id = 0 - layer_norm_nodes = [] - remove_nodes = [] - for reduce_mean in nodes_ReduceMean: - # check that reduce_mean output is Sub - sub = find_output_node(model, reduce_mean.output[0]) - if not is_type(sub, "Sub"): - continue - - # check that sub output[0] is Div and output[1] is Pow - pow, div = find_output_node(model, sub.output[0]) - if is_type(pow, "Cast"): - # During an update in PyTorch, Cast nodes are inserted between Sub and Pow. - remove_nodes += [pow] - pow = find_output_node(model, pow.output[0]) - if not is_type(pow, "Pow"): - continue - cast_pow = find_input_node(model, pow.input[1]) - if not is_type(cast_pow, "Cast"): - continue - remove_nodes += [cast_pow] - if not is_type(div, "Div") or not is_type(pow, "Pow"): - continue - - # check that pow ouput is ReduceMean - reduce_mean2 = find_output_node(model, pow.output[0]) - if not is_type(reduce_mean2, "ReduceMean"): - continue - - # check that reduce_mean2 output is Add - add = find_output_node(model, reduce_mean2.output[0]) - if not is_type(add, "Add"): - continue - - # check that add output is Sqrt - sqrt = find_output_node(model, add.output[0]) - if not is_type(sqrt, "Sqrt"): - continue - - # check that sqrt output is div - if div != find_output_node(model, sqrt.output[0]): - continue - - # check if div output is Mul - optional_mul = find_output_node(model, div.output[0]) - if not is_type(optional_mul, "Mul"): - optional_mul = None - continue # default bias and weight not supported - - # check if mul output is Add - if optional_mul is not None: - optional_add = find_output_node(model, optional_mul.output[0]) - else: - optional_add = find_output_node(model, div.output[0]) - if not is_type(optional_add, "Add"): - optional_add = None - continue # default bias and weight not supported - - # add nodes to remove_nodes - remove_nodes.extend([reduce_mean, sub, div, pow, reduce_mean2, add, sqrt]) - - # create LayerNorm node - layer_norm_input = [] - layer_norm_output = [] - - layer_norm_input.append(reduce_mean.input[0]) - - if optional_mul is not None: - remove_nodes.append(optional_mul) - weight = optional_mul.input[1] - layer_norm_input.append(weight) - - if optional_add is not None: - remove_nodes.append(optional_add) - bias = optional_add.input[1] - layer_norm_input.append(bias) - - if optional_add is not None: - layer_norm_output.append(optional_add.output[0]) - elif optional_mul is not None: - layer_norm_output.append(optional_mul.output[0]) - else: - layer_norm_output.append(div.output[0]) - - layer_norm_output.append("saved_mean_" + str(id)) - layer_norm_output.append("saved_inv_std_var_" + str(id)) - - epsilon_node = find_input_node(model, add.input[1]) - epsilon = epsilon_node.attribute[0].t.raw_data - epsilon = struct.unpack("f", epsilon)[0] - - layer_norm = helper.make_node( - "LayerNormalization", - layer_norm_input, - layer_norm_output, - "LayerNormalization_" + 
str(id), - None, - axis=reduce_mean.attribute[0].ints[0], - epsilon=epsilon, - ) - layer_norm_nodes.append(layer_norm) - id += 1 - - # remove orphan constant nodes - for constant in graph.node: - if constant.op_type == "Constant" and constant not in remove_nodes: - is_orphan = True - for out_name in constant.output: - out = find_output_node(model, out_name) - if out not in remove_nodes: - is_orphan = False - if is_orphan: - remove_nodes.append(constant) - - all_nodes = [] - for node in graph.node: - if node not in remove_nodes: - all_nodes.append(node) - - for node in layer_norm_nodes: - all_nodes.append(node) # noqa: PERF402 - - graph.ClearField("node") - graph.node.extend(all_nodes) - return model - - -# Fuse SoftmaxCrossEntropy - - -def fuse_softmaxNLL_to_softmaxCE(onnx_model): # noqa: N802 - # Converting below subgraph - # - # (subgraph) - # | - # LogSoftmax (target) (optional weight) - # | | | - # nll_loss/NegativeLogLikelihoodLoss - # | - # output - # - # to the following - # - # (subgraph) (target) (optional weight) - # | | _____| - # | | | - # SparseSoftmaxCrossEntropy - # | - # output - nll_count = 0 - while True: - nll_count = nll_count + 1 - nll_loss_node = None - nll_loss_node_index = 0 - for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss": - nll_loss_node = node - break - - if nll_loss_node is None: - break - - softmax_node = None - softmax_node_index = 0 - label_input_name = None - weight_input_name = None - for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "LogSoftmax": - # has to be connected to nll_loss - if len(nll_loss_node.input) > 2: - weight_input_name = nll_loss_node.input[2] - if node.output[0] == nll_loss_node.input[0]: - softmax_node = node - label_input_name = nll_loss_node.input[1] - break - elif node.output[0] == nll_loss_node.input[1]: - softmax_node = node - label_input_name = nll_loss_node.input[0] - break - else: - if softmax_node is not None: - break - - if softmax_node is None: - break - - # delete nll_loss and LogSoftmax nodes in order - if nll_loss_node_index < softmax_node_index: - del onnx_model.graph.node[softmax_node_index] - del onnx_model.graph.node[nll_loss_node_index] - else: - del onnx_model.graph.node[nll_loss_node_index] - del onnx_model.graph.node[softmax_node_index] - - probability_output_name = softmax_node.output[0] - node = onnx_model.graph.node.add() - inputs = ( - [softmax_node.input[0], label_input_name, weight_input_name] - if weight_input_name - else [softmax_node.input[0], label_input_name] - ) - node.CopyFrom( - onnx.helper.make_node( - "SparseSoftmaxCrossEntropy", - inputs, - [nll_loss_node.output[0], probability_output_name], - "nll_loss_node_" + str(nll_count), - ) - ) - - return onnx_model diff --git a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py b/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py deleted file mode 100644 index f57f55d14eb1b..0000000000000 --- a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py +++ /dev/null @@ -1,144 +0,0 @@ -import sys -import threading -import time - - -class OutputGrabber: - """ - Class used to grab standard output or another stream. 
- """ - - escape_char = "\b" - - def __init__(self, stream=None, threaded=False): - self.origstream = stream - self.threaded = threaded - if self.origstream is None: - self.origstream = sys.stdout - self.origstreamfd = self.origstream.fileno() - self.capturedtext = "" - # Create a pipe so the stream can be captured: - self.pipe_out, self.pipe_in = os.pipe() - - def __enter__(self): - self.start() - return self - - def __exit__(self, type, value, traceback): - self.stop() - - def start(self): - """ - Start capturing the stream data. - """ - self.capturedtext = "" - # Save a copy of the stream: - self.streamfd = os.dup(self.origstreamfd) - # Replace the original stream with our write pipe: - os.dup2(self.pipe_in, self.origstreamfd) - if self.threaded: - # Start thread that will read the stream: - self.workerThread = threading.Thread(target=self.readOutput) - self.workerThread.start() - # Make sure that the thread is running and os.read() has executed: - time.sleep(0.01) - - def stop(self): - """ - Stop capturing the stream data and save the text in `capturedtext`. - """ - # Print the escape character to make the readOutput method stop: - self.origstream.write(self.escape_char) - # Flush the stream to make sure all our data goes in before - # the escape character: - self.origstream.flush() - if self.threaded: - # wait until the thread finishes so we are sure that - # we have until the last character: - self.workerThread.join() - else: - self.readOutput() - # Close the pipe: - os.close(self.pipe_in) - os.close(self.pipe_out) - # Restore the original stream: - os.dup2(self.streamfd, self.origstreamfd) - # Close the duplicate stream: - os.close(self.streamfd) - - def readOutput(self): - """ - Read the stream data (one byte at a time) - and save the text in `capturedtext`. 
- """ - while True: - char = os.read(self.pipe_out, 1).decode(self.origstream.encoding) - if not char or self.escape_char in char: - break - self.capturedtext += char - - -import os # noqa: E402 -import unittest # noqa: E402 - -import numpy as np # noqa: E402, F401 -import torch # noqa: E402 -import torch.nn as nn # noqa: E402 -import torch.nn.functional as F # noqa: E402 - -from onnxruntime.capi import _pybind_state as torch_ort_eager # noqa: E402, F401 -from onnxruntime.training import optim, orttrainer, orttrainer_options # noqa: E402, F401 - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x, target): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return my_loss(out, target) - - -class OrtEPTests(unittest.TestCase): - def test_external_graph_transformer_triggering(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - batch_size = 128 - model = NeuralNet(input_size, hidden_size, num_classes) - - model_desc = { - "inputs": [ - ("x", [batch_size, input_size]), - ( - "target", - [ - batch_size, - ], - ), - ], - "outputs": [("loss", [], True)], - } - optim_config = optim.SGDConfig() - opts = orttrainer.ORTTrainerOptions({"device": {"id": "cpu"}}) - model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - # because orttrainer is lazy initialized, feed in a random data to trigger the graph transformer - data = torch.rand(batch_size, input_size) - target = torch.randint(0, 10, (batch_size,)) - - with OutputGrabber() as out: - model.train_step(data, target) - assert "******************Trigger Customized Graph Transformer: MyGraphTransformer!" in out.capturedtext - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc b/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc deleted file mode 100644 index 00e933dd14914..0000000000000 --- a/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc +++ /dev/null @@ -1,35 +0,0 @@ -#include "core/optimizer/rewrite_rule.h" -#include "orttraining/core/optimizer/graph_transformer_registry.h" -#include "onnx/defs/schema.h" -#include -#include - -namespace onnxruntime { -namespace training { - -class MyRewriteRule : public RewriteRule { - public: - MyRewriteRule() noexcept - : RewriteRule("MyRewriteRule") { - } - std::vector TargetOpTypes() const noexcept override { - return {}; - } - - private: - bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/, const logging::Logger& /*logger*/) const override { - return true; - } - - Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/, const logging::Logger& /*logger*/) const override { - std::cout << "******************Trigger Customized Graph Transformer: MyGraphTransformer!" 
<< std::endl; - return Status::OK(); - } -}; - -void RegisterTrainingExternalTransformers() { - ONNX_REGISTER_EXTERNAL_REWRITE_RULE(MyRewriteRule, Level1, true); -} - -} // namespace training -} // namespace onnxruntime diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 20b9354d85745..b774fec11cc8d 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -35,6 +35,7 @@ #ifdef ENABLE_TRITON #include "orttraining/core/optimizer/triton_fusion.h" #endif +#include "orttraining/core/optimizer/conv1d_replacement.h" #include @@ -1199,6 +1200,103 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) { ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1); } +TEST_F(GraphTransformationTests, Conv1dReplacement) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + for (auto group : {1, 2}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / group, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(group)); + }; + + auto post_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 0); + // after graph transformation, the graph should have 1 squeeze, 2 split, group matmul, 1 concat + TEST_RETURN_IF_NOT(op_count_map["Squeeze"] == 1); + TEST_RETURN_IF_NOT(op_count_map["Split"] == 2); + TEST_RETURN_IF_NOT(op_count_map["MatMul"] == group); + TEST_RETURN_IF_NOT(op_count_map["Concat"] == 1); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + } +} + +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + // "group" is 3 so conv not replaced + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + 
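The Conv1dReplacement tests above assert that a Conv node with kernel_shape == 1 gets rewritten into Squeeze, per-group Split, MatMul, and Concat nodes. A small standalone numpy sketch (illustrative only, not part of the change) of why that decomposition is numerically equivalent, using the same shapes as the test (batch 8, 16 input channels, length 128, 64 output channels, group 2):

import numpy as np

def conv1d_k1(x, w, group):
    # Reference grouped Conv1d with kernel size 1, no bias. x: (N, Cin, L), w: (Cout, Cin/group, 1).
    n, cin, length = x.shape
    cout = w.shape[0]
    cin_g, cout_g = cin // group, cout // group
    out = np.empty((n, cout, length), dtype=x.dtype)
    for g in range(group):
        xg = x[:, g * cin_g:(g + 1) * cin_g, :]            # (N, Cin/group, L)
        wg = w[g * cout_g:(g + 1) * cout_g, :, 0]          # (Cout/group, Cin/group)
        out[:, g * cout_g:(g + 1) * cout_g, :] = np.einsum("oc,ncl->nol", wg, xg)
    return out

def conv1d_as_matmul(x, w, group):
    # The rewrite: squeeze the kernel dim, split per group, MatMul each group, concat channels.
    w2 = np.squeeze(w, axis=2)                             # Squeeze: (Cout, Cin/group)
    x_splits = np.split(x, group, axis=1)                  # Split input channels
    w_splits = np.split(w2, group, axis=0)                 # Split output channels
    outs = [np.matmul(wg, xg) for wg, xg in zip(w_splits, x_splits)]  # one MatMul per group
    return np.concatenate(outs, axis=1)                    # Concat along channels

x = np.random.randn(8, 16, 128).astype(np.float32)
w = np.random.randn(64, 8, 1).astype(np.float32)           # group=2, so Cin/group == 8
assert np.allclose(conv1d_k1(x, w, 2), conv1d_as_matmul(x, w, 2), atol=1e-5)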
conv_node.AddAttribute("group", static_cast(3)); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } + + // "kernel_shape" is not 1 so conv not replaced + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{2}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(1)); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } +} + INSTANTIATE_TEST_SUITE_P( QDQFusionTests, QDQFusionTestsParameterized, diff --git a/orttraining/orttraining/test/python/_test_commons.py b/orttraining/orttraining/test/python/_test_commons.py index 1413d59096832..fb7e62551de63 100644 --- a/orttraining/orttraining/test/python/_test_commons.py +++ b/orttraining/orttraining/test/python/_test_commons.py @@ -1,26 +1,7 @@ -import copy -import math import os import subprocess import sys -import numpy as np -import onnx -import torch -from numpy.testing import assert_allclose - -import onnxruntime -from onnxruntime.training import _utils, optim - - -def _single_run(execution_file, scenario, checkopint_dir=None): - cmd = [sys.executable, execution_file] - if scenario: - cmd += ["--scenario", scenario] - if checkopint_dir: - cmd += ["--checkpoint_dir", checkopint_dir] - assert subprocess.call(cmd) == 0 - def is_windows(): return sys.platform.startswith("win") @@ -46,197 +27,3 @@ def run_subprocess(args, cwd=None, capture=False, dll_path=None, shell=False, en if log: log.debug("Subprocess completed. 
Return code=" + str(completed_process.returncode)) return completed_process - - -def legacy_constant_lr_scheduler(global_step, initial_lr, total_steps, warmup): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - new_lr = initial_lr - return new_lr - - -def legacy_cosine_lr_scheduler(global_step, initial_lr, total_steps, warmup, cycles): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - progress = float(global_step - num_warmup_steps) / float(max(1, total_steps - num_warmup_steps)) - new_lr = initial_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * progress))) - return new_lr - - -def legacy_linear_lr_scheduler(global_step, initial_lr, total_steps, warmup): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - new_lr = initial_lr * max(0.0, float(total_steps - global_step) / float(max(1, total_steps - num_warmup_steps))) - return new_lr - - -def legacy_poly_lr_scheduler(global_step, initial_lr, total_steps, warmup, power, lr_end): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - elif global_step > total_steps: - new_lr = lr_end - else: - lr_range = initial_lr - lr_end - decay_steps = total_steps - num_warmup_steps - pct_remaining = 1 - (global_step - num_warmup_steps) / decay_steps - decay = lr_range * pct_remaining**power + lr_end - new_lr = decay - return new_lr - - -def generate_dummy_optim_state(model, optimizer): - np.random.seed(0) - if not (isinstance(optimizer, (optim.AdamConfig, optim.LambConfig))): - return dict() - - moment_keys = ["Moment_1", "Moment_2"] - uc_key = "Update_Count" - step_key = "Step" - shared_state_key = "shared_optimizer_state" - - optim_state = dict() - weight_shape_map = dict() - if isinstance(model, torch.nn.Module): - weight_shape_map = {name: param.size() for name, param in model.named_parameters()} - elif isinstance(model, onnx.ModelProto): - weight_shape_map = {n.name: n.dims for n in model.graph.initializer} - else: - raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'") - - for weight_name, weight_shape in weight_shape_map.items(): - per_weight_state = dict() - for moment in moment_keys: - per_weight_state[moment] = np.random.uniform(-2, 2, weight_shape).astype(np.float32) - if isinstance(optimizer, optim.AdamConfig): - per_weight_state[uc_key] = np.full([1], 5, dtype=np.int64) - optim_state[weight_name] = copy.deepcopy(per_weight_state) - if isinstance(optimizer, optim.LambConfig): - step_val = np.full([1], 5, dtype=np.int64) - optim_state[shared_state_key] = {step_key: step_val} - return {"optimizer": optim_state, "trainer_options": {"optimizer_name": optimizer.name}} - - -def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False, data_dir=None): - # Loads external Pytorch TransformerModel into utils - root = "samples" - if not os.path.exists(root): - root = os.path.normpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "..", "samples") - ) - if not os.path.exists(root): - raise FileNotFoundError("Unable to find folder 'samples', tried %r." 
% root) - pytorch_transformer_path = os.path.join(root, "python", "training", "orttrainer", "pytorch_transformer") - pt_model_path = os.path.join(pytorch_transformer_path, "pt_model.py") - pt_model = _utils.import_module_from_file(pt_model_path) - ort_utils_path = os.path.join(pytorch_transformer_path, "ort_utils.py") - ort_utils = _utils.import_module_from_file(ort_utils_path) - utils_path = os.path.join(pytorch_transformer_path, "utils.py") - utils = _utils.import_module_from_file(utils_path) - - # Modeling - model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - my_loss = ort_utils.my_loss - if legacy_api: - if dynamic_axes: - model_desc = ort_utils.legacy_transformer_model_description_dynamic_axes() - else: - model_desc = ort_utils.legacy_transformer_model_description() - else: - if dynamic_axes: - model_desc = ort_utils.transformer_model_description_dynamic_axes() - else: - model_desc = ort_utils.transformer_model_description() - - # Preparing data - train_data, val_data, test_data = utils.prepare_data(device, 20, 20, data_dir) - return model, model_desc, my_loss, utils.get_batch, train_data, val_data, test_data - - -def generate_random_input_from_bart_model_desc(desc, seed=1, device="cuda:0"): - """Generates a sample input for the BART model using the model desc""" - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - dtype = torch.int64 - vocab_size = 30528 - sample_input = [] - for _index, input in enumerate(desc["inputs"]): - size = [] - for s in input[1]: - if isinstance(s, (int)): - size.append(s) - else: - size.append(1) - sample_input.append(torch.randint(0, vocab_size, tuple(size), dtype=dtype).to(device)) - return sample_input - - -def _load_bart_model(): - bart_onnx_model_path = os.path.join("testdata", "bart_tiny.onnx") - model = onnx.load(bart_onnx_model_path) - batch = 2 - seq_len = 1024 - model_desc = { - "inputs": [ - ( - "src_tokens", - [batch, seq_len], - ), - ( - "prev_output_tokens", - [batch, seq_len], - ), - ( - "target", - [batch * seq_len], - ), - ], - "outputs": [("loss", [], True)], - } - - return model, model_desc - - -def assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint, reshape_states=False): - """Assert that the two ORTTrainer (hierarchical) state dictionaries are very close for all states""" - - assert ("model" in state_dict_pre_checkpoint) == ("model" in state_dict_post_checkpoint) - assert ("optimizer" in state_dict_pre_checkpoint) == ("optimizer" in state_dict_post_checkpoint) - - if "model" in state_dict_pre_checkpoint: - for model_state_key in state_dict_pre_checkpoint["model"]["full_precision"]: - if reshape_states: - assert_allclose( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key], - state_dict_post_checkpoint["model"]["full_precision"][model_state_key].reshape( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key].shape - ), - ) - else: - assert_allclose( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key], - state_dict_post_checkpoint["model"]["full_precision"][model_state_key], - ) - - if "optimizer" in state_dict_pre_checkpoint: - for model_state_key in state_dict_pre_checkpoint["optimizer"]: - for optimizer_state_key in state_dict_pre_checkpoint["optimizer"][model_state_key]: - if reshape_states: - assert_allclose( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key], - state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key].reshape( - 
state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key].shape - ), - ) - else: - assert_allclose( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key], - state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key], - ) diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index a9a4c7b1cc2ef..8f2a18b5ec00b 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -1,30 +1,11 @@ import copy import os -import numpy as np import torch from numpy.testing import assert_allclose -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import orttrainer - -try: - from onnxruntime.training.ortmodule import ORTModule - from onnxruntime.training.ortmodule._fallback import ORTModuleInitException - from onnxruntime.training.ortmodule._graph_execution_manager_factory import ( # noqa: F401 - GraphExecutionManagerFactory, - ) -except ImportError: - # Some pipelines do not contain ORTModule - pass -except Exception as e: - from onnxruntime.training.ortmodule._fallback import ORTModuleInitException - - if isinstance(e, ORTModuleInitException): - # ORTModule is present but not ready to run - # That is OK because this file is also used by ORTTrainer tests - pass - raise +from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory # noqa: F401 def is_all_or_nothing_fallback_enabled(model, policy=None): @@ -66,103 +47,6 @@ def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0): assert_allclose(output_a, output_b, rtol=rtol, atol=atol, err_msg="Model output value mismatch") -def assert_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0): - r"""Asserts whether weight difference between models a and b differences are within specified tolerance - - Compares the weights of two different ONNX models (model_a and model_b) - and raises AssertError when they diverge by more than atol or rtol - - Args: - model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure - verbose (bool, default is False): if True, prints absolute difference for each weight - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer) - state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state() - assert len(state_dict_a.items()) == len(state_dict_b.items()) - _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol) - - -def assert_legacy_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0): - r"""Asserts whether weight difference between models a and b differences are within specified tolerance - - Compares the weights of a legacy model model_a and experimental model_b model - and raises AssertError when they diverge by more than atol or rtol. - - Args: - model_a (ORTTrainer): Instance of legacy ORTTrainer - model_b (ORTTrainer): Instance of experimental ORTTrainer - verbose (bool, default is False): if True, prints absolute difference for each weight. 
- rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, Legacy_ORTTrainer) - state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b.session.get_state() - assert len(state_dict_a.items()) == len(state_dict_b.items()) - _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol) - - -def _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol): - r"""Asserts whether dicts a and b value differences are within specified tolerance - - Compares the weights of two model's state_dict dicts and raises AssertError - when they diverge by more than atol or rtol - - Args: - model_a (ORTTrainer): Instance of legacy ORTTrainer - model_b (ORTTrainer): Instance of experimental ORTTrainer - verbose (bool, default is False): if True, prints absolute difference for each weight. - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - - for (a_name, a_val), (_b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()): - np_a_vals = np.array(a_val).flatten() - np_b_vals = np.array(b_val).flatten() - assert np_a_vals.shape == np_b_vals.shape - if verbose: - print(f"Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}") - assert_allclose(a_val, b_val, rtol=rtol, atol=atol, err_msg=f"Weight mismatch for {a_name}") - - -def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0): - r"""Asserts whether optimizer state differences are within specified tolerance - - Compares the expected and actual optimizer states of dicts and raises AssertError - when they diverge by more than atol or rtol. - The optimizer dict is of the form: - model_weight_name: - { - "Moment_1": moment1_tensor, - "Moment_2": moment2_tensor, - "Update_Count": update_tensor # if optimizer is adam, absent otherwise - }, - ... - "shared_optimizer_state": # if optimizer is shared, absent otherwise. - So far, only lamb optimizer uses this. 
- { - "step": step_tensor # int array of size 1 - } - - Args: - expected_state (dict(dict())): Expected optimizer state - actual_state (dict(dict())): Actual optimizer state - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 0): Max absolute difference - """ - assert expected_state.keys() == actual_state.keys() - for param_name, a_state in actual_state.items(): - for k, v in a_state.items(): - assert_allclose( - v, - expected_state[param_name][k], - rtol=rtol, - atol=atol, - err_msg=f"Optimizer state mismatch for param {param_name}, key {k}", - ) - - def is_dynamic_axes(model): # Check inputs for inp in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.input: diff --git a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py b/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py deleted file mode 100644 index d5298cf8e860e..0000000000000 --- a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py +++ /dev/null @@ -1,325 +0,0 @@ -import os -import unittest - -import torch -import torch.nn as nn -from orttraining_test_bert_postprocess import postprocess_model -from orttraining_test_data_loader import create_ort_test_dataloader -from orttraining_test_transformers import BertForPreTraining, BertModelTest -from orttraining_test_utils import map_optimizer_attributes - -import onnxruntime -from onnxruntime.capi.ort_trainer import ( # noqa: F401 - IODescription, - LossScaler, - ModelDescription, - ORTTrainer, - generate_sample, -) - -torch.manual_seed(1) -onnxruntime.set_seed(1) - - -class Test_PostPasses(unittest.TestCase): # noqa: N801 - def get_onnx_model( - self, model, model_desc, inputs, device, _enable_internal_postprocess=True, _extra_postprocess=None - ): - lr_desc = IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes, - lr_desc, - device, - world_rank=0, - world_size=1, - _opset_version=14, - _enable_internal_postprocess=_enable_internal_postprocess, - _extra_postprocess=_extra_postprocess, - ) - - model.train_step(*inputs) - return model.onnx_model_ - - def count_all_nodes(self, model): - return len(model.graph.node) - - def count_nodes(self, model, node_type): - count = 0 - for node in model.graph.node: - if node.op_type == node_type: - count += 1 - return count - - def find_nodes(self, model, node_type): - nodes = [] - for node in model.graph.node: - if node.op_type == node_type: - nodes.append(node) - return nodes - - def get_name(self, name): - if os.path.exists(name): - return name - rel = os.path.join("testdata", name) - if os.path.exists(rel): - return rel - this = os.path.dirname(__file__) - data = os.path.join(this, "..", "..", "..", "..", "onnxruntime", "test", "testdata") - res = os.path.join(data, name) - if os.path.exists(res): - return res - raise FileNotFoundError(f"Unable to find '{name}' or '{rel}' or '{res}'") - - def test_layer_norm(self): - class LayerNormNet(nn.Module): - def __init__(self, target): - super().__init__() - self.ln_1 = nn.LayerNorm(10) - self.loss = nn.CrossEntropyLoss() - self.target = target - - def forward(self, x): - output1 = self.ln_1(x) - loss = self.loss(output1, self.target) - return loss, output1 - - device = torch.device("cpu") - target = torch.ones(20, 10, 10, dtype=torch.int64).to(device) - model = LayerNormNet(target) - input = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(device) - - input_desc = 
IODescription("input", [], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [20, 5, 10, 10], "float32") - model_desc = ModelDescription([input_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [input, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, input_args, device) - - count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization") - count_nodes = self.count_all_nodes(onnx_model) - - assert count_layer_norm == 0 - assert count_nodes == 3 - - def test_expand(self): - class ExpandNet(nn.Module): - def __init__(self, target): - super().__init__() - self.loss = nn.CrossEntropyLoss() - self.target = target - self.linear = torch.nn.Linear(2, 2) - - def forward(self, x, x1): - output = x.expand_as(x1) - output = self.linear(output) - output = output + output - loss = self.loss(output, self.target) - return loss, output - - device = torch.device("cpu") - target = torch.ones(5, 5, 2, dtype=torch.int64).to(device) - model = ExpandNet(target).to(device) - - x = torch.randn(5, 3, 1, 2, dtype=torch.float32).to(device) - x1 = torch.randn(5, 3, 5, 2, dtype=torch.float32).to(device) - - input0_desc = IODescription("x", [5, 3, 1, 2], "float32") - input1_desc = IODescription("x1", [5, 3, 5, 2], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [5, 3, 5, 2], "float32") - model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [x, x1, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, input_args, device) - - # check that expand output has shape - expand_nodes = self.find_nodes(onnx_model, "Expand") - assert len(expand_nodes) == 1 - - model_info = onnx_model.graph.value_info - assert model_info[0].name == expand_nodes[0].output[0] - assert model_info[0].type == onnx_model.graph.input[1].type - - def test_bert(self): - device = torch.device("cpu") - - model_tester = BertModelTest.BertModelTester(self) - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = model_tester.prepare_config_and_inputs() - - model = BertForPreTraining(config=config) - model.eval() - - loss, prediction_scores, seq_relationship_score = model( - input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - masked_lm_labels=token_labels, - next_sentence_label=sequence_labels, - ) - - model_desc = ModelDescription( - [ - model_tester.input_ids_desc, - model_tester.attention_mask_desc, - model_tester.token_type_ids_desc, - model_tester.masked_lm_labels_desc, - model_tester.next_sentence_label_desc, - ], - [model_tester.loss_desc, model_tester.prediction_scores_desc, model_tester.seq_relationship_scores_desc], - ) - - from collections import namedtuple - - MyArgs = namedtuple( - "MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len" - ) - args = MyArgs( - local_rank=0, - world_size=1, - max_steps=100, - learning_rate=0.00001, - warmup_proportion=0.01, - batch_size=13, - seq_len=7, - ) - - dataset_len = 100 - dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device) - learning_rate = torch.tensor(1.0e0, dtype=torch.float32).to(device) - for b in dataloader: - batch = b - break - learning_rate = torch.tensor([1.00e00]).to(device) - inputs = 
[*batch, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, inputs, device, _extra_postprocess=postprocess_model) - - self._bert_helper(onnx_model) - - def _bert_helper(self, onnx_model): - # count layer_norm - count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization") - assert count_layer_norm == 0 - - # get expand node and check output shape - expand_nodes = self.find_nodes(onnx_model, "Expand") - assert len(expand_nodes) == 1 - - model_info = onnx_model.graph.value_info - assert model_info[0].name == expand_nodes[0].output[0] - assert model_info[0].type == onnx_model.graph.input[0].type - - def test_extra_postpass(self): - def postpass_replace_first_add_with_sub(model): - # this post pass replaces the first Add node with Sub in the model. - # Previous graph - # (subgraph 1) (subgraph 2) - # | | - # | | - # |________ ________| - # | | - # Add - # | - # (subgraph 3) - # - # Post graph - # (subgraph 1) (subgraph 2) - # | | - # | | - # |________ ________| - # | | - # Sub - # | - # (subgraph 3) - add_nodes = [n for n in model.graph.node if n.op_type == "Add"] - add_nodes[0].op_type = "Sub" - - class MultiAdd(nn.Module): - def __init__(self, target): - super().__init__() - self.loss = nn.CrossEntropyLoss() - self.target = target - self.linear = torch.nn.Linear(2, 2, bias=False) - - def forward(self, x, x1): - output = x + x1 - output = output + x - output = output + x1 - output = self.linear(output) - loss = self.loss(output, self.target) - return loss, output - - device = torch.device("cpu") - target = torch.ones(5, 2, dtype=torch.int64).to(device) - model = MultiAdd(target).to(device) - - x = torch.randn(5, 5, 2, dtype=torch.float32).to(device) - x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device) - - input0_desc = IODescription("x", [5, 5, 2], "float32") - input1_desc = IODescription("x1", [5, 5, 2], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [5, 5, 2], "float32") - model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [x, x1, learning_rate] - - onnx_model = self.get_onnx_model( - model, model_desc, input_args, device, _extra_postprocess=postpass_replace_first_add_with_sub - ) - - # check that extra postpass is called, and called only once. - add_nodes = self.find_nodes(onnx_model, "Add") - sub_nodes = self.find_nodes(onnx_model, "Sub") - assert len(add_nodes) == 2 - assert len(sub_nodes) == 1 - - unprocessed_onnx_model = self.get_onnx_model( - model, model_desc, input_args, device, _extra_postprocess=None, _enable_internal_postprocess=False - ) - # check that the model is unchanged. - add_nodes = self.find_nodes(unprocessed_onnx_model, "Add") - sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub") - assert len(add_nodes) == 3 - assert len(sub_nodes) == 0 - - processed_onnx_model = self.get_onnx_model( - unprocessed_onnx_model, - model_desc, - input_args, - device, - _extra_postprocess=postpass_replace_first_add_with_sub, - ) - # check that extra postpass is called, and called only once. 
- add_nodes = self.find_nodes(processed_onnx_model, "Add") - sub_nodes = self.find_nodes(processed_onnx_model, "Sub") - assert len(add_nodes) == 2 - assert len(sub_nodes) == 1 - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index 0e7e9d23ee627..5341cd053ac18 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -43,7 +43,7 @@ def run_ortmodule_ops_tests(cwd, log, transformers_cache): env = get_env_with_transformers_cache(transformers_cache) - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnx_ops_ortmodule.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_onnx_ops.py"] run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode() @@ -146,7 +146,7 @@ def run_data_sampler_tests(cwd, log): def run_hooks_tests(cwd, log): log.debug("Running: Data hooks tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_hooks.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_hooks.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() diff --git a/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py b/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py deleted file mode 100644 index eea733684f140..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py +++ /dev/null @@ -1,801 +0,0 @@ -# ================== -import dataclasses -import datetime -import glob -import json -import logging -import os -import random -import shutil -import unittest -from concurrent.futures import ProcessPoolExecutor -from dataclasses import dataclass, field -from typing import Any, Dict, Optional - -import h5py -import numpy as np -import torch -import torch.distributed as dist -from torch.utils.data import DataLoader, Dataset, RandomSampler -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers import BertConfig, BertForPreTraining, HfArgumentParser - -import onnxruntime as ort - -# need to override torch.onnx.symbolic_opset12.nll_loss to handle ignore_index == -100 cases. -# the fix for ignore_index == -100 cases is already in pytorch master. -# however to use current torch master is causing computation changes in many tests. -# eventually we will use pytorch with fixed nll_loss once computation -# issues are understood and solved. -import onnxruntime.capi.pt_patch -from onnxruntime.training import amp, optim, orttrainer -from onnxruntime.training.checkpoint import aggregate_checkpoints -from onnxruntime.training.optim import LinearWarmupLRScheduler, PolyWarmupLRScheduler # noqa: F401 - -# we cannot make full convergence run in nightly pipeling because of its timeout limit, -# max_steps is still needed to calculate learning rate. force_to_stop_max_steps is used to -# terminate the training before the pipeline run hit its timeout. 
-force_to_stop_max_steps = 2500 - -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO -) -logger = logging.getLogger(__name__) - - -def get_rank(): - if not dist.is_available(): - return 0 - if not dist.is_initialized(): - return 0 - return dist.get_rank() - - -def is_main_process(args): - if hasattr(args, "world_rank"): - return args.world_rank in [-1, 0] - else: - return get_rank() == 0 - - -def bert_model_description(config): - vocab_size = config.vocab_size - new_model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - ), - ( - "next_sentence_label", - [ - "batch", - ], - ), - ], - "outputs": [ - ("loss", [], True), - ( - "prediction_scores", - ["batch", "max_seq_len_in_batch", vocab_size], - ), - ( - "seq_relationship_scores", - ["batch", 2], - ), - ], - } - return new_model_desc - - -def create_pretraining_dataset(input_file, max_pred_length, args): - train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length) - train_sampler = RandomSampler(train_data) - train_dataloader = DataLoader( - train_data, sampler=train_sampler, batch_size=args.train_batch_size * args.n_gpu, num_workers=0, pin_memory=True - ) - return train_dataloader, input_file - - -class pretraining_dataset(Dataset): # noqa: N801 - def __init__(self, input_file, max_pred_length): - logger.info("pretraining_dataset: %s, max_pred_length: %d", input_file, max_pred_length) - self.input_file = input_file - self.max_pred_length = max_pred_length - f = h5py.File(input_file, "r") - keys = [ - "input_ids", - "input_mask", - "segment_ids", - "masked_lm_positions", - "masked_lm_ids", - "next_sentence_labels", - ] - self.inputs = [np.asarray(f[key][:]) for key in keys] - f.close() - - def __len__(self): - "Denotes the total number of samples" - return len(self.inputs[0]) - - def __getitem__(self, index): - [input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) - if indice < 5 - else torch.from_numpy(np.asarray(input[index].astype(np.int64))) - for indice, input in enumerate(self.inputs) - ] - - # HF model use default ignore_index value (-100) for CrossEntropyLoss - masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -100 - index = self.max_pred_length - # store number of masked tokens in index - padded_mask_indices = (masked_lm_positions == 0).nonzero() - if len(padded_mask_indices) != 0: - index = padded_mask_indices[0].item() - masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] - return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels] - - -import argparse # noqa: E402 - - -def parse_arguments(): - parser = argparse.ArgumentParser() - - # batch size test config parameters - parser.add_argument( - "--enable_mixed_precision", - default=False, - action="store_true", - help="Whether to use 16-bit float precision instead of 32-bit", - ) - - parser.add_argument( - "--sequence_length", - default=512, - type=int, - help="The maximum total input sequence length after WordPiece tokenization. 
\n" - "Sequences longer than this will be truncated, and sequences shorter \n" - "than this will be padded.", - ) - parser.add_argument( - "--max_predictions_per_seq", default=80, type=int, help="The maximum total of masked tokens in input sequence" - ) - parser.add_argument("--max_batch_size", default=32, type=int, help="Total batch size for training.") - - parser.add_argument("--gelu_recompute", default=False, action="store_true") - - parser.add_argument("--attn_dropout_recompute", default=False, action="store_true") - - parser.add_argument("--transformer_layer_recompute", default=False, action="store_true") - - args = parser.parse_args() - return args - - -@dataclass -class PretrainArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - input_dir: str = field( - default=None, metadata={"help": "The input data dir. Should contain .hdf5 files for the task"} - ) - - bert_model: str = field( - default=None, - metadata={ - "help": "Bert pre-trained model selected in the list: bert-base-uncased, \ - bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." - }, - ) - - output_dir: str = field( - default=None, metadata={"help": "The output directory where the model checkpoints will be written."} - ) - - cache_dir: str = field( - default="/tmp/bert_pretrain/", - metadata={"help": "The output directory where the model checkpoints will be written."}, - ) - max_seq_length: Optional[int] = field( - default=512, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer \ - than this will be truncated, sequences shorter will be padded." - }, - ) - - max_predictions_per_seq: Optional[int] = field( - default=80, metadata={"help": "The maximum total of masked tokens in input sequence."} - ) - - train_batch_size: Optional[int] = field(default=32, metadata={"help": "Batch size for training."}) - - learning_rate: Optional[float] = field(default=5e-5, metadata={"help": "The initial learning rate for Lamb."}) - - num_train_epochs: Optional[float] = field( - default=3.0, metadata={"help": "Total number of training epochs to perform."} - ) - - max_steps: Optional[float] = field(default=1000, metadata={"help": "Total number of training steps to perform."}) - - warmup_proportion: Optional[float] = field( - default=0.01, - metadata={ - "help": "Proportion of training to perform linear learning rate warmup for. \ - E.g., 0.1 = 10%% of training." 
- }, - ) - - local_rank: Optional[int] = field(default=-1, metadata={"help": "local_rank for distributed training on gpus."}) - - world_rank: Optional[int] = field(default=-1) - - world_size: Optional[int] = field(default=1) - - seed: Optional[int] = field(default=42, metadata={"help": "random seed for initialization."}) - - gradient_accumulation_steps: Optional[int] = field( - default=1, metadata={"help": "Number of updates steps to accumualte before performing a backward/update pass."} - ) - - fp16: bool = field(default=False, metadata={"help": "Whether to use 16-bit float precision instead of 32-bit."}) - - gelu_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing Gelu activation output to save memory."} - ) - attn_dropout_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing attention dropout to save memory."} - ) - transformer_layer_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing transformer layerwise to save memory."} - ) - - loss_scale: Optional[float] = field( - default=0.0, metadata={"help": "Loss scaling, positive power of 2 values can improve fp16 convergence."} - ) - - deepspeed_zero_stage: Optional[int] = field(default=0, metadata={"help": "Deepspeed Zero Stage. 0 => disabled"}) - - log_freq: Optional[float] = field(default=1.0, metadata={"help": "frequency of logging loss."}) - - checkpoint_activations: bool = field(default=False, metadata={"help": "Whether to use gradient checkpointing."}) - - resume_from_checkpoint: bool = field( - default=False, metadata={"help": "Whether to resume training from checkpoint."} - ) - - resume_step: Optional[int] = field(default=-1, metadata={"help": "Step to resume training from."}) - - num_steps_per_checkpoint: Optional[int] = field( - default=100, metadata={"help": "Number of update steps until a model checkpoint is saved to disk."} - ) - - save_checkpoint: Optional[bool] = field( - default=False, metadata={"help": "Enable for saving a model checkpoint to disk."} - ) - - init_state_dict: Optional[dict] = field(default=None, metadata={"help": "State to load before training."}) - - phase2: bool = field(default=False, metadata={"help": "Whether to train with seq len 512."}) - - allreduce_post_accumulation: bool = field( - default=False, metadata={"help": "Whether to do allreduces during gradient accumulation steps."} - ) - - allreduce_post_accumulation_fp16: bool = field( - default=False, metadata={"help": "Whether to do fp16 allreduce post accumulation."} - ) - - accumulate_into_fp16: bool = field(default=False, metadata={"help": "Whether to use fp16 gradient accumulators."}) - - phase1_end_step: Optional[int] = field( - default=7038, metadata={"help": "Whether to use fp16 gradient accumulators."} - ) - - tensorboard_dir: Optional[str] = field( - default=None, - ) - - schedule: Optional[str] = field( - default="warmup_poly", - ) - - # this argument is test specific. to run a full bert model will take too long to run. instead, we reduce - # number of hidden layers so that it can show convergence to an extend to help detect any regression. - force_num_hidden_layers: Optional[int] = field( - default=None, metadata={"help": "Whether to use fp16 gradient accumulators."} - ) - - def to_json_string(self): - """ - Serializes this instance to a JSON string. 
- """ - return json.dumps(dataclasses.asdict(self), indent=2) - - def to_sanitized_dict(self) -> Dict[str, Any]: - """ - Sanitized serialization to use with TensorBoard`s hparams - """ - d = dataclasses.asdict(self) - valid_types = [bool, int, float, str, torch.Tensor] - return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} - - -def setup_training(args): - assert torch.cuda.is_available() - - if args.local_rank == -1: - args.local_rank = 0 - args.world_rank = 0 - - print("args.local_rank: ", args.local_rank) - torch.cuda.set_device(args.local_rank) - device = torch.device("cuda", args.local_rank) - args.n_gpu = 1 - - if args.gradient_accumulation_steps < 1: - raise ValueError( - f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1" - ) - if args.train_batch_size % args.gradient_accumulation_steps != 0: - raise ValueError( - "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format( - args.gradient_accumulation_steps, args.train_batch_size - ) - ) - - # args.train_batch_size is per global step (optimization step) batch size - # now make it a per gpu batch size - args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps - args.train_batch_size = args.train_batch_size // args.world_size - - logger.info("setup_training: args.train_batch_size = %d", args.train_batch_size) - return device, args - - -def setup_torch_distributed(world_rank, world_size): - os.environ["RANK"] = str(world_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12345" - torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=world_rank) - return - - -def prepare_model(args, device): - config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir) - - # config.num_hidden_layers = 12 - if args.force_num_hidden_layers: - logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers) - config.num_hidden_layers = args.force_num_hidden_layers - - model = BertForPreTraining(config) - if args.init_state_dict is not None: - model.load_state_dict(args.init_state_dict) - model_desc = bert_model_description(config) - - lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion) - - loss_scaler = amp.DynamicLossScaler() if args.fp16 else None - - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": args.gradient_accumulation_steps}, - "device": {"id": str(device)}, - "mixed_precision": {"enabled": args.fp16, "loss_scaler": loss_scaler}, - "graph_transformer": { - "attn_dropout_recompute": args.attn_dropout_recompute, - "gelu_recompute": args.gelu_recompute, - "transformer_layer_recompute": args.transformer_layer_recompute, - }, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": True}, - "distributed": { - "world_rank": max(0, args.local_rank), - "world_size": args.world_size, - "local_rank": max(0, args.local_rank), - "allreduce_post_accumulation": args.allreduce_post_accumulation, - "deepspeed_zero_optimization": {"stage": args.deepspeed_zero_stage}, - "enable_adasum": False, - }, - "lr_scheduler": lr_scheduler, - } - ) - - param_optimizer = list(model.named_parameters()) - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - params = [ - { - "params": [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)], - "alpha": 
0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - { - "params": [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - ] - - optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) - model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options) - - return model - - -def get_data_file(f_id, world_rank, world_size, files): - num_files = len(files) - if world_size > num_files: - remainder = world_size % num_files - return files[(f_id * world_size + world_rank + remainder * f_id) % num_files] - elif world_size > 1: - return files[(f_id * world_size + world_rank) % num_files] - else: - return files[f_id % num_files] - - -def main(): - parser = HfArgumentParser(PretrainArguments) - args = parser.parse_args_into_dataclasses()[0] - do_pretrain(args) - - -def do_pretrain(args): - if is_main_process(args) and args.tensorboard_dir: - tb_writer = SummaryWriter(log_dir=args.tensorboard_dir) - tb_writer.add_text("args", args.to_json_string()) - tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={}) - else: - tb_writer = None - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - ort.set_seed(args.seed) - - device, args = setup_training(args) - - model = prepare_model(args, device) - - logger.info("Running training: Batch size = %d, initial LR = %f", args.train_batch_size, args.learning_rate) - - average_loss = 0.0 - epoch = 0 - training_steps = 0 - - pool = ProcessPoolExecutor(1) - while True: - files = [ - os.path.join(args.input_dir, f) - for f in os.listdir(args.input_dir) - if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in f - ] - files.sort() - random.shuffle(files) - - f_id = 0 - train_dataloader, data_file = create_pretraining_dataset( - get_data_file(f_id, args.world_rank, args.world_size, files), args.max_predictions_per_seq, args - ) - - for f_id in range(1, len(files)): - logger.info("data file %s" % (data_file)) - - dataset_future = pool.submit( - create_pretraining_dataset, - get_data_file(f_id, args.world_rank, args.world_size, files), - args.max_predictions_per_seq, - args, - ) - - train_iter = tqdm(train_dataloader, desc="Iteration") if is_main_process(args) else train_dataloader - for _step, batch in enumerate(train_iter): - training_steps += 1 - batch = [t.to(device) for t in batch] # noqa: PLW2901 - input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch - - loss, _, _ = model.train_step( - input_ids, input_mask, segment_ids, masked_lm_labels, next_sentence_labels - ) - average_loss += loss.item() - - global_step = model._train_step_info.optimization_step - if training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0: - if is_main_process(args): - divisor = args.log_freq * args.gradient_accumulation_steps - if tb_writer: - lr = model.options.lr_scheduler.get_last_lr()[0] - tb_writer.add_scalar("train/summary/scalar/Learning_Rate", lr, global_step) - if args.fp16: - tb_writer.add_scalar("train/summary/scalar/loss_scale_25", loss, global_step) - # TODO: ORTTrainer to expose all_finite - # tb_writer.add_scalar('train/summary/scalar/all_fp16_gradients_finite_859', all_finite, global_step) - tb_writer.add_scalar("train/summary/total_loss", average_loss / divisor, global_step) - - print(f"Step:{global_step} Average Loss = {average_loss / divisor}") - - if global_step >= args.max_steps or global_step >= 
force_to_stop_max_steps: - if tb_writer: - tb_writer.close() - - if global_step >= args.max_steps: - if args.save_checkpoint: - model.save_checkpoint(os.path.join(args.output_dir, f"checkpoint-{args.world_rank}.ortcp")) - final_loss = average_loss / (args.log_freq * args.gradient_accumulation_steps) - return final_loss - - average_loss = 0 - - del train_dataloader - - train_dataloader, data_file = dataset_future.result(timeout=None) - - epoch += 1 - - -def generate_tensorboard_logdir(root_dir): - current_date_time = datetime.datetime.today() - - dt_string = current_date_time.strftime("BERT_pretrain_%y_%m_%d_%I_%M_%S") - return os.path.join(root_dir, dt_string) - - -class ORTBertPretrainTest(unittest.TestCase): - def setUp(self): - self.output_dir = "/bert_data/hf_data/test_out/bert_pretrain_results" - self.bert_model = "bert-base-uncased" - self.local_rank = -1 - self.world_rank = -1 - self.world_size = 1 - self.max_steps = 300000 - self.learning_rate = 5e-4 - self.max_seq_length = 512 - self.max_predictions_per_seq = 20 - self.input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - self.train_batch_size = 4096 - self.gradient_accumulation_steps = 64 - self.fp16 = True - self.allreduce_post_accumulation = True - self.tensorboard_dir = "/bert_data/hf_data/test_out" - - def test_pretrain_throughput(self, process_args=None): - if process_args.sequence_length == 128: - input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - else: - input_dir = "/bert_data/hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - - print("process_args.enable_mixed_precision: ", process_args.enable_mixed_precision) - print("process_args.sequence_length: ", process_args.sequence_length) - print("process_args.max_batch_size: ", process_args.max_batch_size) - print("process_args.max_predictions_per_seq: ", process_args.max_predictions_per_seq) - print("process_args.gelu_recompute: ", process_args.gelu_recompute) - print("process_args.attn_dropout_recompute: ", process_args.attn_dropout_recompute) - print("process_args.transformer_layer_recompute: ", process_args.transformer_layer_recompute) - - args = PretrainArguments( - input_dir=input_dir, - output_dir="/bert_data/hf_data/test_out/bert_pretrain_results", - bert_model="bert-large-uncased", - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=10, - learning_rate=5e-4, - max_seq_length=process_args.sequence_length, - max_predictions_per_seq=process_args.max_predictions_per_seq, - train_batch_size=process_args.max_batch_size, - gradient_accumulation_steps=1, - fp16=process_args.enable_mixed_precision, - gelu_recompute=process_args.gelu_recompute, - attn_dropout_recompute=process_args.attn_dropout_recompute, - transformer_layer_recompute=process_args.transformer_layer_recompute, - allreduce_post_accumulation=True, - # TODO: remove - force_num_hidden_layers=2, - ) - do_pretrain(args) - - def test_pretrain_convergence(self): - args = PretrainArguments( - output_dir=self.output_dir, - bert_model=self.bert_model, - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=self.max_steps, - learning_rate=self.learning_rate, - max_seq_length=self.max_seq_length, - max_predictions_per_seq=self.max_predictions_per_seq, - 
train_batch_size=self.train_batch_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - input_dir=self.input_dir, - fp16=self.fp16, - allreduce_post_accumulation=self.allreduce_post_accumulation, - force_num_hidden_layers=self.force_num_hidden_layers, - tensorboard_dir=generate_tensorboard_logdir("/bert_data/hf_data/test_out/"), - ) - final_loss = do_pretrain(args) - return final_loss - - def test_pretrain_zero(self): - assert self.world_size > 0, "ZeRO test requires a distributed run." - setup_torch_distributed(self.world_rank, self.world_size) - per_gpu_batch_size = 32 - optimization_batch_size = per_gpu_batch_size * self.world_size # set to disable grad accumulation - - self.train_batch_size = optimization_batch_size - self.gradient_accumulation_steps = 1 - self.deepspeed_zero_stage = 1 - self.force_num_hidden_layers = 2 - self.max_seq_length = 32 - self.output_dir = "./bert_pretrain_ckpt" - if self.world_rank == 0: - if os.path.isdir(self.output_dir): - shutil.rmtree(self.output_dir) - os.makedirs(self.output_dir, exist_ok=True) - - torch.distributed.barrier() - - assert os.path.exists(self.output_dir) - - # run a few optimization steps - self.max_steps = 200 - args = PretrainArguments( - output_dir=self.output_dir, - bert_model=self.bert_model, - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=self.max_steps, - learning_rate=self.learning_rate, - max_seq_length=self.max_seq_length, - max_predictions_per_seq=self.max_predictions_per_seq, - train_batch_size=self.train_batch_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - input_dir=self.input_dir, - fp16=self.fp16, - allreduce_post_accumulation=self.allreduce_post_accumulation, - force_num_hidden_layers=self.force_num_hidden_layers, - deepspeed_zero_stage=self.deepspeed_zero_stage, - save_checkpoint=True, - ) - do_pretrain(args) - - # ensure all workers reach this point before loading the checkpointed state - torch.distributed.barrier() - - # on rank 0, load the trained state - if args.world_rank == 0: - checkpoint_files = glob.glob(os.path.join(self.output_dir, "checkpoint*.ortcp")) - args.init_state_dict = aggregate_checkpoints(checkpoint_files, pytorch_format=True) - - torch.distributed.barrier() - - # run a single step to get the loss, on rank 0 should be lesser than starting loss - args.save_checkpoint = False - args.max_steps = 1 - args.deepspeed_zero_stage = 0 - final_loss = do_pretrain(args) - return final_loss - - -if __name__ == "__main__": - import sys - - logger.warning("sys.argv: %s", sys.argv) - # usage: - # data parallel training - # mpirun -n 4 python orttraining_run_bert_pretrain.py - # - # single gpu: - # python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput - # [batch size test arguments] - # python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_convergence - # - # pytorch.distributed.launch will not work because ort backend requires MPI to broadcast ncclUniqueId - # calling unpublished get_mpi_context_xxx to get rank/size numbers. - try: - # In case ORT is not built with MPI/NCCL, there are no get_mpi_context_xxx internal apis. 
- from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size - - has_get_mpi_context_internal_api = True - except ImportError: - has_get_mpi_context_internal_api = False - pass - if has_get_mpi_context_internal_api and get_mpi_context_world_size() > 1: - world_size = get_mpi_context_world_size() - print("get_mpi_context_world_size(): ", world_size) - local_rank = get_mpi_context_local_rank() - - if local_rank == 0: - print("================================================================> os.getpid() = ", os.getpid()) - - test = ORTBertPretrainTest() - test.setUp() - test.local_rank = local_rank - test.world_rank = local_rank - test.world_size = world_size - - if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_zero": - logger.info("running ORTBertPretrainTest.test_pretrain_zero()...") - final_loss = test.test_pretrain_zero() - logger.info("ORTBertPretrainTest.test_pretrain_zero() rank = %i final loss = %f", local_rank, final_loss) - if local_rank == 0: - test.assertLess(final_loss, 10.2) - else: - test.assertGreater(final_loss, 11.0) - logger.info("ORTBertPretrainTest.test_pretrain_zero() passed") - elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence": - logger.info("running ORTBertPretrainTest.test_pretrain_convergence()...") - test.max_steps = 200 - test.force_num_hidden_layers = 8 - final_loss = test.test_pretrain_convergence() - logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss) - test.assertLess(final_loss, 8.5) - logger.info("ORTBertPretrainTest.test_pretrain_convergence() passed") - else: - # https://microsoft.sharepoint.com/teams/ONNX2/_layouts/15/Doc.aspx?sourcedoc={170774be-e1c6-4f8b-a3ae-984f211fe410}&action=edit&wd=target%28ONNX%20Training.one%7C8176133b-c7cb-4ef2-aa9d-3fdad5344c40%2FGitHub%20Master%20Merge%20Schedule%7Cb67f0db1-e3a0-4add-80a6-621d67fd8107%2F%29 - # to make equivalent args for cpp convergence test - test.max_seq_length = 128 - test.max_predictions_per_seq = 20 - test.gradient_accumulation_steps = 16 - - # cpp_batch_size (=64) * grad_acc * world_size - test.train_batch_size = 64 * test.gradient_accumulation_steps * test.world_size - test.max_steps = 300000 - - test.force_num_hidden_layers = None - - # already using Adam (e.g. 
AdamConfig) - test.learning_rate = 5e-4 - test.warmup_proportion = 0.1 - - final_loss = test.test_pretrain_convergence() - logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss) - else: - # unittest does not accept user defined arguments - # we need to run this script with user defined arguments - if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_throughput": - run_test_pretrain_throughput, run_test_pretrain_convergence = True, False - sys.argv.remove("ORTBertPretrainTest.test_pretrain_throughput") - elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence": - run_test_pretrain_throughput, run_test_pretrain_convergence = False, True - sys.argv.remove("ORTBertPretrainTest.test_pretrain_convergence") - else: - run_test_pretrain_throughput, run_test_pretrain_convergence = True, True - process_args = parse_arguments() - test = ORTBertPretrainTest() - test.setUp() - - if run_test_pretrain_throughput: - logger.info("running single GPU ORTBertPretrainTest.test_pretrain_throughput()...") - test.test_pretrain_throughput(process_args) - logger.info("single GPU ORTBertPretrainTest.test_pretrain_throughput() passed") - - # unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py b/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py deleted file mode 100644 index 3e2d1a7154bfd..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py +++ /dev/null @@ -1,67 +0,0 @@ -import collections -import subprocess -import sys - -Config = collections.namedtuple( - "Config", - [ - "enable_mixed_precision", - "sequence_length", - "max_batch_size", - "max_predictions_per_seq", - "gelu_recompute", - "attn_dropout_recompute", - "transformer_layer_recompute", - ], -) - -configs = [ - Config(True, 128, 46, 20, False, False, False), - Config(True, 512, 8, 80, False, False, False), - Config(False, 128, 26, 20, False, False, False), - Config(False, 512, 4, 80, False, False, False), - Config(True, 128, 50, 20, True, False, False), - Config(True, 128, 50, 20, False, True, False), - Config(True, 128, 76, 20, False, False, True), - Config(True, 512, 8, 80, True, False, False), - Config(True, 512, 9, 80, False, True, False), - Config(True, 512, 15, 80, False, False, True), -] - - -def run_with_config(config): - print( - "##### testing name - {}-{} #####".format( - "fp16" if config.enable_mixed_precision else "fp32", config.sequence_length - ) - ) - print("gelu_recompute: ", config.gelu_recompute) - print("attn_dropout_recompute: ", config.attn_dropout_recompute) - print("transformer_layer_recompute: ", config.transformer_layer_recompute) - - cmds = [ - sys.executable, - "orttraining_run_bert_pretrain.py", - "ORTBertPretrainTest.test_pretrain_throughput", - "--sequence_length", - str(config.sequence_length), - "--max_batch_size", - str(config.max_batch_size), - "--max_predictions_per_seq", - str(config.max_predictions_per_seq), - ] - if config.enable_mixed_precision: - cmds.append("--enable_mixed_precision") - if config.gelu_recompute: - cmds.append("--gelu_recompute") - if config.attn_dropout_recompute: - cmds.append("--attn_dropout_recompute") - if config.transformer_layer_recompute: - cmds.append("--transformer_layer_recompute") - - # access to azure storage shared disk is much slower so we need a longer timeout. 
- subprocess.run(cmds, timeout=1200).check_returncode() # noqa: PLW1510 - - -for config in configs: - run_with_config(config) diff --git a/orttraining/orttraining/test/python/orttraining_run_glue.py b/orttraining/orttraining/test/python/orttraining_run_glue.py deleted file mode 100644 index 794e2f8cc7240..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_glue.py +++ /dev/null @@ -1,323 +0,0 @@ -# adapted from run_glue.py of huggingface transformers - -import dataclasses # noqa: F401 -import logging -import os -import unittest -from dataclasses import dataclass, field -from typing import Dict, Optional - -import numpy as np -from numpy.testing import assert_allclose -from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - AutoTokenizer, - EvalPrediction, - GlueDataset, - GlueDataTrainingArguments, - TrainingArguments, - glue_compute_metrics, - glue_output_modes, - glue_tasks_num_labels, - set_seed, -) - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - -try: - from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size - - has_get_mpi_context_internal_api = True -except ImportError: - has_get_mpi_context_internal_api = False - pass - - -import torch # noqa: F401 -from orttraining_transformer_trainer import ORTTransformerTrainer - -logger = logging.getLogger(__name__) - - -def verify_old_and_new_api_are_equal(results_per_api): - new_api_results = results_per_api[True] - old_api_results = results_per_api[False] - for key in new_api_results: - assert_allclose(new_api_results[key], old_api_results[key]) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
- """ - - model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"}) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} - ) - - -class ORTGlueTest(unittest.TestCase): - def setUp(self): - # configurations not to be changed accoss tests - self.max_seq_length = 128 - self.train_batch_size = 8 - self.learning_rate = 2e-5 - self.num_train_epochs = 3.0 - self.local_rank = -1 - self.world_size = 1 - self.overwrite_output_dir = True - self.gradient_accumulation_steps = 1 - self.data_dir = "/bert_data/hf_data/glue_data/" - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "glue_test_output/") - self.cache_dir = "/tmp/glue/" - self.logging_steps = 10 - - def test_roberta_with_mrpc(self): - expected_acc = 0.85 - expected_f1 = 0.88 - expected_loss = 0.35 - results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=False) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_roberta_fp16_with_mrpc(self): - expected_acc = 0.87 - expected_f1 = 0.90 - expected_loss = 0.33 - - results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=True) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_bert_with_mrpc(self): - if self.local_rank == -1: - expected_acc = 0.83 - expected_f1 = 0.88 - expected_loss = 0.44 - elif self.local_rank == 0: - expected_acc = 0.81 - expected_f1 = 0.86 - expected_loss = 0.44 - - results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False) - - if self.local_rank in [-1, 0]: - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_bert_fp16_with_mrpc(self): - expected_acc = 0.84 - expected_f1 = 0.88 - expected_loss = 0.46 - - results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=True) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def model_to_desc(self, model_name, model): - if model_name.startswith("bert") or model_name.startswith("xlnet"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "labels", - [ - "batch", - ], - ), - ], - "outputs": [("loss", [], True), ("logits", ["batch", 2])], - } - elif model_name.startswith("roberta"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "labels", - [ - "batch", - ], - ), - ], - "outputs": [("loss", [], True), ("logits", ["batch", 2])], - } - else: - raise RuntimeError(f"unsupported base model name {model_name}.") - - return model_desc - - def run_glue(self, model_name, task_name, fp16): - model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) - data_args = GlueDataTrainingArguments( - 
task_name=task_name, data_dir=os.path.join(self.data_dir, task_name), max_seq_length=self.max_seq_length - ) - - training_args = TrainingArguments( - output_dir=os.path.join(self.output_dir, task_name), - do_train=True, - do_eval=True, - per_gpu_train_batch_size=self.train_batch_size, - learning_rate=self.learning_rate, - num_train_epochs=self.num_train_epochs, - local_rank=self.local_rank, - overwrite_output_dir=self.overwrite_output_dir, - gradient_accumulation_steps=self.gradient_accumulation_steps, - fp16=fp16, - logging_steps=self.logging_steps, - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, - ) - logger.warning( - "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - training_args.local_rank, - training_args.device, - training_args.n_gpu, - bool(training_args.local_rank != -1), - training_args.fp16, - ) - logger.info("Training/evaluation parameters %s", training_args) - - set_seed(training_args.seed) - onnxruntime.set_seed(training_args.seed) - - try: - num_labels = glue_tasks_num_labels[data_args.task_name] - output_mode = glue_output_modes[data_args.task_name] - except KeyError: - raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904 - - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task=data_args.task_name, - cache_dir=model_args.cache_dir, - ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - ) - - model = AutoModelForSequenceClassification.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - ) - - train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None - - eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None - - def compute_metrics(p: EvalPrediction) -> Dict: - if output_mode == "classification": - preds = np.argmax(p.predictions, axis=1) - elif output_mode == "regression": - preds = np.squeeze(p.predictions) - return glue_compute_metrics(data_args.task_name, preds, p.label_ids) - - model_desc = self.model_to_desc(model_name, model) - # Initialize the ORTTrainer within ORTTransformerTrainer - trainer = ORTTransformerTrainer( - model=model, - model_desc=model_desc, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - compute_metrics=compute_metrics, - world_size=self.world_size, - ) - - # Training - if training_args.do_train: - trainer.train() - trainer.save_model() - - # Evaluation - results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: - logger.info("*** Evaluate ***") - - result = trainer.evaluate() - - logger.info(f"***** Eval results {data_args.task_name} *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - - results.update(result) - - return results - - -if __name__ == "__main__": - if has_get_mpi_context_internal_api: - local_rank = get_mpi_context_local_rank() - world_size = get_mpi_context_world_size() - else: - local_rank = -1 - world_size = 1 - - if world_size > 1: - # mpi launch - logger.warning("mpirun launch, 
local_rank / world_size: %s : % s", local_rank, world_size) - - # TrainingArguments._setup_devices will call torch.distributed.init_process_group(backend="nccl") - # pytorch expects following environment settings (which would be set if launched with torch.distributed.launch). - - os.environ["RANK"] = str(local_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "29500" - - from onnxruntime.capi._pybind_state import set_cuda_device_id - - set_cuda_device_id(local_rank) - - test = ORTGlueTest() - test.setUp() - test.local_rank = local_rank - test.world_size = world_size - test.test_bert_with_mrpc() - else: - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py b/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py deleted file mode 100644 index 92db204593bcd..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py +++ /dev/null @@ -1,281 +0,0 @@ -# adapted from run_multiple_choice.py of huggingface transformers -# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/run_multiple_choice.py - -import dataclasses # noqa: F401 -import logging -import os -import unittest -from dataclasses import dataclass, field -from typing import Dict, Optional - -import numpy as np -import torch # noqa: F401 -from numpy.testing import assert_allclose # noqa: F401 -from orttraining_run_glue import verify_old_and_new_api_are_equal # noqa: F401 -from orttraining_transformer_trainer import ORTTransformerTrainer -from transformers import HfArgumentParser # noqa: F401 -from transformers import Trainer # noqa: F401 -from transformers import ( - AutoConfig, - AutoModelForMultipleChoice, - AutoTokenizer, - EvalPrediction, - TrainingArguments, - set_seed, -) -from utils_multiple_choice import MultipleChoiceDataset, Split, SwagProcessor - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - -logger = logging.getLogger(__name__) - - -def simple_accuracy(preds, labels): - return (preds == labels).mean() - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"}) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} - ) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ - - task_name: str = field(metadata={"help": "The name of the task to train on."}) - data_dir: str = field(metadata={"help": "Should contain the data files for the task."}) - max_seq_length: int = field( - default=128, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." 
- }, - ) - overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}) - - -class ORTMultipleChoiceTest(unittest.TestCase): - def setUp(self): - # configurations not to be changed accoss tests - self.max_seq_length = 80 - self.train_batch_size = 16 - self.eval_batch_size = 2 - self.learning_rate = 2e-5 - self.num_train_epochs = 1.0 - self.local_rank = -1 - self.overwrite_output_dir = True - self.gradient_accumulation_steps = 8 - self.data_dir = "/bert_data/hf_data/swag/swagaf/data" - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "multiple_choice_test_output/") - self.cache_dir = "/tmp/multiple_choice/" - self.logging_steps = 10 - self.rtol = 2e-01 - - def test_bert_with_swag(self): - expected_acc = 0.75 - expected_loss = 0.64 - - results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=False) - assert results["acc"] >= expected_acc - assert results["loss"] <= expected_loss - - def test_bert_fp16_with_swag(self): - # larger batch can be handled with mixed precision - self.train_batch_size = 32 - - expected_acc = 0.73 - expected_loss = 0.68 - - results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=True) - assert results["acc"] >= expected_acc - assert results["loss"] <= expected_loss - - def run_multiple_choice(self, model_name, task_name, fp16): - model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) - data_args = DataTrainingArguments( - task_name=task_name, data_dir=self.data_dir, max_seq_length=self.max_seq_length - ) - - training_args = TrainingArguments( - output_dir=os.path.join(self.output_dir, task_name), - do_train=True, - do_eval=True, - per_gpu_train_batch_size=self.train_batch_size, - per_gpu_eval_batch_size=self.eval_batch_size, - learning_rate=self.learning_rate, - num_train_epochs=self.num_train_epochs, - local_rank=self.local_rank, - overwrite_output_dir=self.overwrite_output_dir, - gradient_accumulation_steps=self.gradient_accumulation_steps, - fp16=fp16, - logging_steps=self.logging_steps, - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, - ) - logger.warning( - "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - training_args.local_rank, - training_args.device, - training_args.n_gpu, - bool(training_args.local_rank != -1), - training_args.fp16, - ) - logger.info("Training/evaluation parameters %s", training_args) - - set_seed(training_args.seed) - onnxruntime.set_seed(training_args.seed) - - try: - processor = SwagProcessor() - label_list = processor.get_labels() - num_labels = len(label_list) - except KeyError: - raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904 - - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task=data_args.task_name, - cache_dir=model_args.cache_dir, - ) - - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - ) - - model = AutoModelForMultipleChoice.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, 
- ) - - # Get datasets - train_dataset = ( - MultipleChoiceDataset( - data_dir=data_args.data_dir, - tokenizer=tokenizer, - task=data_args.task_name, - processor=processor, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.train, - ) - if training_args.do_train - else None - ) - eval_dataset = ( - MultipleChoiceDataset( - data_dir=data_args.data_dir, - tokenizer=tokenizer, - task=data_args.task_name, - processor=processor, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.dev, - ) - if training_args.do_eval - else None - ) - - def compute_metrics(p: EvalPrediction) -> Dict: - preds = np.argmax(p.predictions, axis=1) - return {"acc": simple_accuracy(preds, p.label_ids)} - - if model_name.startswith("bert"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "labels", - ["batch", num_labels], - ), - ], - "outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])], - } - else: - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "labels", - ["batch", num_labels], - ), - ], - "outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])], - } - - # Initialize the ORTTrainer within ORTTransformerTrainer - trainer = ORTTransformerTrainer( - model=model, - model_desc=model_desc, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - compute_metrics=compute_metrics, - ) - - # Training - if training_args.do_train: - trainer.train() - trainer.save_model() - - # Evaluation - results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: - logger.info("*** Evaluate ***") - - result = trainer.evaluate() - - logger.info(f"***** Eval results {data_args.task_name} *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - - results.update(result) - - return results - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py b/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py deleted file mode 100644 index 71e6bb8e4d2f2..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py +++ /dev/null @@ -1,6 +0,0 @@ -from orttraining_test_layer_norm_transform import layer_norm_transform # noqa: F401 -from orttraining_test_model_transform import add_expand_shape, add_name, fix_transpose # noqa: F401 - - -def postprocess_model(model): - add_name(model) diff --git a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py b/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py deleted file mode 100644 index 21372caaf6779..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# orttraining_test_checkpoint_storage.py - -import os -import pickle -import shutil - -import numpy as np -import pytest -import torch - -from onnxruntime.training import _checkpoint_storage - -# Helper functions - - -def _equals(a, b): - """Checks recursively if two dictionaries are equal""" - if isinstance(a, dict): - return all(not (key not in b or not _equals(a[key], b[key])) for key in a) - else: - if isinstance(a, bytes): - a = a.decode() - if isinstance(b, bytes): - b = b.decode() - are_equal = a == b - return are_equal if isinstance(are_equal, bool) else are_equal.all() - - return False - - -def _numpy_types(obj_value): - """Return a bool indicating whether or not the input obj_value is a numpy type object - - Recursively checks if the obj_value (could be a dictionary) is a numpy type object. - Exceptions are str and bytes. - - Returns true if object is numpy type, str, or bytes - False if any other type - """ - if not isinstance(obj_value, dict): - return isinstance(obj_value, (str, bytes)) or type(obj_value).__module__ == np.__name__ - - return all(_numpy_types(value) for _, value in obj_value.items()) - - -def _get_dict(separated_key): - """Create dummy dictionary with different datatypes - - Returns the tuple of the entire dummy dictionary created, key argument as a dictionary for _checkpoint_storage.load - function and the value for that key in the original dictionary - - For example the complete dictionary is represented by: - { - 'int1':1, - 'int2': 2, - 'int_list': [1,2,3,5,6], - 'dict1': { - 'np_array': np.arange(100), - 'dict2': {'int3': 3, 'int4': 4}, - 'str1': "onnxruntime" - }, - 'bool1': bool(True), - 'int5': 5, - 'float1': 2.345, - 'np_array_float': np.array([1.234, 2.345, 3.456]), - 'np_array_float_3_dim': np.array([[[1,2],[3,4]], [[5,6],[7,8]]]) - } - - if the input key is ['dict1', 'str1'], then the key argument returned is 'dict1/str1' - and the value corresponding to that is "onnxruntime" - - so, for the above example, the returned tuple is: - (original_dict, {'key': 'dict1/str1', "onnxruntime") - """ - test_dict = { - "int1": 1, - "int2": 2, - "int_list": [1, 2, 3, 5, 6], - "dict1": {"np_array": np.arange(100), "dict2": {"int3": 3, "int4": 4}, "str1": "onnxruntime"}, - "bool1": True, - "int5": 5, - "float1": 2.345, - "np_array_float": np.array([1.234, 2.345, 3.456]), - "np_array_float_3_dim": np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), - } - key = "" - expected_val = test_dict - for single_key in separated_key: - key += single_key + "/" - expected_val = expected_val[single_key] - return test_dict, {"key": key} if len(separated_key) > 0 else dict(), expected_val - - -class _CustomClass: - """Custom object that encpsulates dummy values for loss, epoch and train_step""" - - def __init__(self): - self._loss = 1.23 - self._epoch = 12000 - self._train_step = 25 - - def __eq__(self, other): - if isinstance(other, _CustomClass): - return self._loss == other._loss and self._epoch == other._epoch and self._train_step == other._train_step - - -# Test fixtures - - -@pytest.yield_fixture(scope="function") -def checkpoint_storage_test_setup(): - checkpoint_dir = os.path.abspath("checkpoint_dir/") - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir, exist_ok=True) - pytest.checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ortcp") - yield "checkpoint_storage_test_setup" - shutil.rmtree(checkpoint_dir) - - -@pytest.yield_fixture(scope="function") -def checkpoint_storage_test_parameterized_setup(request, checkpoint_storage_test_setup): - yield 
request.param - - -# Tests - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - _get_dict([]), - _get_dict(["int1"]), - _get_dict(["dict1"]), - _get_dict(["dict1", "dict2"]), - _get_dict(["dict1", "dict2", "int4"]), - _get_dict(["dict1", "str1"]), - _get_dict(["bool1"]), - _get_dict(["float1"]), - _get_dict(["np_array_float"]), - ], - indirect=True, -) -def test_checkpoint_storage_saved_dict_matches_loaded(checkpoint_storage_test_parameterized_setup): - to_save = checkpoint_storage_test_parameterized_setup[0] - key_arg = checkpoint_storage_test_parameterized_setup[1] - expected = checkpoint_storage_test_parameterized_setup[2] - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - loaded = _checkpoint_storage.load(pytest.checkpoint_path, **key_arg) - assert _equals(loaded, expected) - assert _numpy_types(loaded) - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [{"int_set": {1, 2, 3, 4, 5}}, {"str_set": {"one", "two"}}, [1, 2, 3], 2.352], - indirect=True, -) -def test_checkpoint_storage_saving_non_supported_types_fails(checkpoint_storage_test_parameterized_setup): - to_save = checkpoint_storage_test_parameterized_setup - with pytest.raises(Exception): # noqa: B017 - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - ({"int64_tensor": torch.tensor(np.arange(100))}, "int64_tensor", torch.int64, np.int64), - ({"int32_tensor": torch.tensor(np.arange(100), dtype=torch.int32)}, "int32_tensor", torch.int32, np.int32), - ({"int16_tensor": torch.tensor(np.arange(100), dtype=torch.int16)}, "int16_tensor", torch.int16, np.int16), - ({"int8_tensor": torch.tensor(np.arange(100), dtype=torch.int8)}, "int8_tensor", torch.int8, np.int8), - ({"float64_tensor": torch.tensor(np.array([1.0, 2.0]))}, "float64_tensor", torch.float64, np.float64), - ( - {"float32_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32)}, - "float32_tensor", - torch.float32, - np.float32, - ), - ( - {"float16_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float16)}, - "float16_tensor", - torch.float16, - np.float16, - ), - ], - indirect=True, -) -def test_checkpoint_storage_saving_tensor_datatype(checkpoint_storage_test_parameterized_setup): - tensor_dict = checkpoint_storage_test_parameterized_setup[0] - tensor_name = checkpoint_storage_test_parameterized_setup[1] - tensor_dtype = checkpoint_storage_test_parameterized_setup[2] - np_dtype = checkpoint_storage_test_parameterized_setup[3] - - _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert isinstance(loaded[tensor_name], np.ndarray) - assert tensor_dict[tensor_name].dtype == tensor_dtype - assert loaded[tensor_name].dtype == np_dtype - assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all() - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - ({"two_dim": torch.ones([2, 4], dtype=torch.float64)}, "two_dim"), - ({"three_dim": torch.ones([2, 4, 6], dtype=torch.float64)}, "three_dim"), - ({"four_dim": torch.ones([2, 4, 6, 8], dtype=torch.float64)}, "four_dim"), - ], - indirect=True, -) -def test_checkpoint_storage_saving_multiple_dimension_tensors(checkpoint_storage_test_parameterized_setup): - tensor_dict = checkpoint_storage_test_parameterized_setup[0] - tensor_name = checkpoint_storage_test_parameterized_setup[1] - - _checkpoint_storage.save(tensor_dict, 
pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert isinstance(loaded[tensor_name], np.ndarray) - assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all() - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", [{}, {"a": {}}, {"a": {"b": {}}}], indirect=True -) -def test_checkpoint_storage_saving_and_loading_empty_dictionaries_succeeds(checkpoint_storage_test_parameterized_setup): - saved = checkpoint_storage_test_parameterized_setup - _checkpoint_storage.save(saved, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert _equals(saved, loaded) - - -def test_checkpoint_storage_load_file_that_does_not_exist_fails(checkpoint_storage_test_setup): - with pytest.raises(Exception): # noqa: B017 - _checkpoint_storage.load(pytest.checkpoint_path) - - -def test_checkpoint_storage_for_custom_user_dict_succeeds(checkpoint_storage_test_setup): - custom_class = _CustomClass() - user_dict = {"tensor1": torch.tensor(np.arange(100), dtype=torch.float32), "custom_class": custom_class} - - pickled_bytes = pickle.dumps(user_dict).hex() - to_save = {"a": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32), "user_dict": pickled_bytes} - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - - loaded_dict = _checkpoint_storage.load(pytest.checkpoint_path) - assert (loaded_dict["a"] == to_save["a"].numpy()).all() - try: # noqa: SIM105 - loaded_dict["user_dict"] = loaded_dict["user_dict"].decode() - except AttributeError: - pass - loaded_obj = pickle.loads(bytes.fromhex(loaded_dict["user_dict"])) - - assert torch.all(loaded_obj["tensor1"].eq(user_dict["tensor1"])) - assert loaded_obj["custom_class"] == custom_class diff --git a/orttraining/orttraining/test/python/orttraining_test_data_loader.py b/orttraining/orttraining/test/python/orttraining_test_data_loader.py index aa15b44ae0d66..0009d2d3d7e1b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_data_loader.py +++ b/orttraining/orttraining/test/python/orttraining_test_data_loader.py @@ -4,8 +4,6 @@ import torch from torch.utils.data import DataLoader, Dataset -from onnxruntime.capi.ort_trainer import generate_sample - global_rng = random.Random() @@ -41,6 +39,16 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() +def generate_sample(desc, device=None): + """Generate a sample based on the description""" + # symbolic dimensions are described with strings. 
set symbolic dimensions to be 1 + size = [s if isinstance(s, (int)) else 1 for s in desc.shape_] + if desc.num_classes_: + return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device) + else: + return torch.randn(size, dtype=desc.dtype_).to(device) + + class OrtTestDataset(Dataset): def __init__(self, input_desc, seq_len, dataset_len, device): import copy diff --git a/orttraining/orttraining/test/python/orttraining_test_debuggability.py b/orttraining/orttraining/test/python/orttraining_test_debuggability.py deleted file mode 100644 index 499f0ba7a1ff5..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_debuggability.py +++ /dev/null @@ -1,40 +0,0 @@ -import pytest -import torch -from _test_commons import _load_pytorch_transformer_model - -from onnxruntime import set_seed -from onnxruntime.training import optim, orttrainer - -############################################################################### -# Testing starts here ######################################################### -############################################################################### - - -@pytest.mark.parametrize( - "seed, device", - [ - (24, "cuda"), - ], -) -def testORTTransformerModelExport(seed, device): - # Common setup - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": { - "check_model_export": True, - }, - "device": { - "id": device, - }, - } - ) - - # Setup for the first ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device) - first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - _ = first_trainer.train_step(data, targets) - assert first_trainer._onnx_model is not None diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index 88d9c00984d3e..2a7012787be6e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -19,6 +19,7 @@ class TestTorchDynamoOrt(unittest.TestCase): def setUp(self): # Make computation deterministic. 
torch.manual_seed(42) + print(f"TestTorchDynamoOrt uses PyTorch version {torch.__version__}") def test_elementwise_model(self): torch._dynamo.reset() diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis.py index 506aafbe9f618..a3e666dd404f2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis.py @@ -27,7 +27,7 @@ def run_training_apis_python_api_tests(cwd, log): log.debug("Running: ort training api tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_python_bindings.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_py_bindings.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() @@ -37,7 +37,7 @@ def run_onnxblock_tests(cwd, log): log.debug("Running: onnxblock tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnxblock.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_onnxblock.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() diff --git a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_onnxblock.py rename to orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py similarity index 99% rename from orttraining/orttraining/test/python/orttraining_test_python_bindings.py rename to orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py index d5c37b3e36ee7..34d8c24ccfab4 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py @@ -11,7 +11,7 @@ import onnx import pytest import torch -from orttraining_test_onnxblock import _get_models +from orttraining_test_ort_apis_onnxblock import _get_models import onnxruntime.training.onnxblock as onnxblock from onnxruntime import OrtValue, SessionOptions diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 13024b81f4b3c..ad0e5d8beba3d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -5761,6 +5761,7 @@ def run_step(model, input, positions): ("MatMul", 1), ("Dropout", 0), ("LayerNormalization", 0), + ("LayerNormalization", 1), ("Cast", 0), ("BiasGelu", 0), ("Gelu", 0), @@ -5773,12 +5774,18 @@ def test_ops_for_padding_elimination(test_cases): test_op = test_cases[0] case = test_cases[1] + vocab_size, hidden_size = 50265, 768 + batch_size, max_seq_length = 8, 128 + class ToyModel(torch.nn.Module): def __init__(self, vocab_size, hidden_size, pad_token_id): super().__init__() self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) if test_op == "LayerNormalization": - self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05) + if case == 0: + self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-05) + else: + self.LayerNorm = nn.LayerNorm([max_seq_length, hidden_size], eps=1e-05) self.hidden_size = hidden_size 
# test test_elementwise op for padding elimination @@ -5889,8 +5896,6 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): batched_inputs.append(torch.cat((input_id, padding))) return torch.stack(batched_inputs) - vocab_size, hidden_size = 50265, 768 - batch_size, max_seq_length = 8, 128 device = "cuda" model = ORTModule(ToyModel(vocab_size, hidden_size, 1).to(device)) x = generate_inputs(batch_size, max_seq_length, vocab_size) @@ -5908,7 +5913,7 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 3 else: assert len([node.op_type for node in training_model.graph.node if node.op_type == "FlattenAndUnpad"]) == 2 - gathergrad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten") + recover_pad_node = next(node for node in training_model.graph.node if node.op_type == "PadAndUnflatten") def find_input_node_type(model, arg): result = [] @@ -5917,14 +5922,14 @@ def find_input_node_type(model, arg): result.append(node) return result[0].op_type if len(result) == 1 else None - gathergrad_input_optypes = [find_input_node_type(training_model, arg) for arg in gathergrad_node.input] + recover_pad_input_optypes = [find_input_node_type(training_model, arg) for arg in recover_pad_node.input] if test_op == "Add" or test_op == "Mul" or test_op == "Sub": - assert test_op in gathergrad_input_optypes + assert test_op in recover_pad_input_optypes else: if case == 0: - assert test_op in gathergrad_input_optypes + assert test_op in recover_pad_input_optypes else: - assert "ATen" in gathergrad_input_optypes + assert "ATen" in recover_pad_input_optypes del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] diff --git a/orttraining/orttraining/test/python/orttraining_test_hooks.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_hooks.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_hooks.py rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_hooks.py diff --git a/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py deleted file mode 100644 index 45b87b32f7d64..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py +++ /dev/null @@ -1,1283 +0,0 @@ -import copy # noqa: F401 -import inspect # noqa: F401 -import math # noqa: F401 -import os -from functools import partial - -import _test_commons -import _test_helpers -import onnx -import pytest -import torch -from numpy.testing import assert_allclose - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription -from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler -from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import amp, optim, orttrainer - -############################################################################### 
-# Helper functions ############################################################ -############################################################################### - - -def generate_random_input_from_model_desc(desc, seed=1, device="cuda:0"): - """Generates a sample input for the BERT model using the model desc""" - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - dtype = torch.int64 - vocab_size = 30528 - num_classes = [vocab_size, 2, 2, vocab_size, 2] - dims = {"batch_size": 16, "seq_len": 1} - sample_input = [] - for index, input in enumerate(desc["inputs"]): - size = [] - for s in input[1]: - if isinstance(s, (int)): - size.append(s) - else: - size.append(dims[s] if s in dims else 1) - sample_input.append(torch.randint(0, num_classes[index], tuple(size), dtype=dtype).to(device)) - return sample_input - - -# EXPERIMENTAL HELPER FUNCTIONS - - -def bert_model_description(dynamic_shape=True): - """Creates the model description dictionary with static dimensions""" - - if dynamic_shape: - model_desc = { - "inputs": [ - ("input_ids", ["batch_size", "seq_len"]), - ( - "segment_ids", - ["batch_size", "seq_len"], - ), - ( - "input_mask", - ["batch_size", "seq_len"], - ), - ( - "masked_lm_labels", - ["batch_size", "seq_len"], - ), - ( - "next_sentence_labels", - [ - "batch_size", - ], - ), - ], - "outputs": [("loss", [], True)], - } - else: - batch_size = 16 - seq_len = 1 - model_desc = { - "inputs": [ - ("input_ids", [batch_size, seq_len]), - ( - "segment_ids", - [batch_size, seq_len], - ), - ( - "input_mask", - [batch_size, seq_len], - ), - ( - "masked_lm_labels", - [batch_size, seq_len], - ), - ( - "next_sentence_labels", - [ - batch_size, - ], - ), - ], - "outputs": [("loss", [], True)], - } - return model_desc - - -def optimizer_parameters(model): - """A method to assign different hyper parameters for different model parameter groups""" - - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay_param_group = [] - for initializer in model.graph.initializer: - if any(key in initializer.name for key in no_decay_keys): - no_decay_param_group.append(initializer.name) - params = [ - { - "params": no_decay_param_group, - "alpha": 0.9, - "beta": 0.999, - "lambda_coef": 0.0, - "epsilon": 1e-6, - "do_bias_correction": False, - } - ] - - return params - - -def load_bert_onnx_model(): - bert_onnx_model_path = os.path.join("testdata", "bert_toy_postprocessed.onnx") - model = onnx.load(bert_onnx_model_path) - return model - - -class CustomLossScaler(amp.LossScaler): - def __init__(self, loss_scale=float(1 << 16)): - super().__init__(loss_scale) - self._initial_loss_scale = loss_scale - self.loss_scale = loss_scale - - def reset(self): - self.loss_scale = self._initial_loss_scale - - def update(self, train_step_info): - self.loss_scale *= 0.9 - return self.loss_scale - - -# LEGACY HELPER FUNCTIONS - - -class LegacyCustomLossScaler: - def __init__(self, loss_scale=float(1 << 16)): - self._initial_loss_scale = loss_scale - self.loss_scale_ = loss_scale - - def reset(self): - self.loss_scale_ = self._initial_loss_scale - - def update_loss_scale(self, is_all_finite): - self.loss_scale_ *= 0.9 - - -def legacy_model_params(lr, device=torch.device("cuda", 0)): # noqa: B008 - legacy_model_desc = legacy_bert_model_description() - learning_rate_description = legacy_ort_trainer_learning_rate_description() - learning_rate = torch.tensor([lr]).to(device) - return (legacy_model_desc, learning_rate_description, learning_rate) - - -def legacy_ort_trainer_learning_rate_description(): - return 
Legacy_IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - - -def legacy_bert_model_description(): - input_ids_desc = Legacy_IODescription("input_ids", ["batch", "max_seq_len_in_batch"]) - segment_ids_desc = Legacy_IODescription("segment_ids", ["batch", "max_seq_len_in_batch"]) - input_mask_desc = Legacy_IODescription("input_mask", ["batch", "max_seq_len_in_batch"]) - masked_lm_labels_desc = Legacy_IODescription("masked_lm_labels", ["batch", "max_seq_len_in_batch"]) - next_sentence_labels_desc = Legacy_IODescription( - "next_sentence_labels", - [ - "batch", - ], - ) - loss_desc = Legacy_IODescription("loss", []) - - return Legacy_ModelDescription( - [input_ids_desc, segment_ids_desc, input_mask_desc, masked_lm_labels_desc, next_sentence_labels_desc], - [loss_desc], - ) - - -def legacy_optim_params_a(name): - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -def legacy_optim_params_b(name): - params = ["bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.weight"] - if name in params: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6, "do_bias_correction": False} - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -def legacy_optim_params_c(name): - params_group = optimizer_parameters(load_bert_onnx_model()) - if name in params_group[0]["params"]: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6, "do_bias_correction": False} - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -############################################################################### -# Testing starts here ######################################################### -############################################################################### - - -@pytest.mark.parametrize("dynamic_shape", [(True), (False)]) -def testToyBERTModelBasicTraining(dynamic_shape): - model_desc = bert_model_description(dynamic_shape) - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({}) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for _i in range(10): - sample_input = generate_random_input_from_model_desc(model_desc) - output = trainer.train_step(*sample_input) - assert output.shape == torch.Size([]) - - -@pytest.mark.parametrize( - "expected_losses", - [([11.041123, 10.986166, 11.101636, 11.013366, 11.03775, 11.041175, 10.957118, 11.069563, 11.040824, 11.16437])], -) -def testToyBERTDeterministicCheck(expected_losses): - # Common setup - train_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optimizer_parameters(model) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - experimental_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "initial_lr, lr_scheduler, expected_learning_rates, 
expected_losses", - [ - ( - 1.0, - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [ - 10.988012313842773, - 10.99213981628418, - 120.79301452636719, - 36.11647033691406, - 95.83200073242188, - 221.2766571044922, - 208.40316772460938, - 279.5332946777344, - 402.46380615234375, - 325.79254150390625, - ], - ), - ( - 0.5, - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], - [ - 10.988012313842773, - 10.99213981628418, - 52.69743347167969, - 19.741533279418945, - 83.88340759277344, - 126.39848327636719, - 91.53898620605469, - 63.62016296386719, - 102.21206665039062, - 180.1424560546875, - ], - ), - ( - 1.0, - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.9931806517013612, - 0.9397368756032445, - 0.8386407858128706, - 0.7008477123264848, - 0.5412896727361662, - 0.37725725642960045, - 0.22652592093878665, - 0.10542974530180327, - 0.02709137914968268, - ], - [ - 10.988012313842773, - 10.99213981628418, - 120.6441650390625, - 32.152557373046875, - 89.63705444335938, - 138.8782196044922, - 117.57748413085938, - 148.01927185058594, - 229.60403442382812, - 110.2930908203125, - ], - ), - ( - 1.0, - optim.lr_scheduler.LinearWarmupLRScheduler, - [ - 0.0, - 0.9473684210526315, - 0.8421052631578947, - 0.7368421052631579, - 0.631578947368421, - 0.5263157894736842, - 0.42105263157894735, - 0.3157894736842105, - 0.21052631578947367, - 0.10526315789473684, - ], - [ - 10.988012313842773, - 10.99213981628418, - 112.89633178710938, - 31.114538192749023, - 80.94029235839844, - 131.34490966796875, - 111.4329605102539, - 133.74252319335938, - 219.37344360351562, - 109.67041015625, - ], - ), - ( - 1.0, - optim.lr_scheduler.PolyWarmupLRScheduler, - [ - 0.0, - 0.9473684263157895, - 0.8421052789473684, - 0.7368421315789474, - 0.6315789842105263, - 0.5263158368421054, - 0.42105268947368424, - 0.31578954210526317, - 0.21052639473684212, - 0.10526324736842106, - ], - [ - 10.988012313842773, - 10.99213981628418, - 112.89633178710938, - 31.114538192749023, - 80.9402847290039, - 131.3447265625, - 111.43253326416016, - 133.7415008544922, - 219.37147521972656, - 109.66986083984375, - ], - ), - ], -) -def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates, expected_losses): - return # TODO: re-enable after nondeterminism on backend is fixed - # Common setup - device = "cuda" - total_steps = 10 - seed = 1 - warmup = 0.05 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Setup LR Schedulers - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig(lr=initial_lr) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "lr_scheduler": lr_scheduler, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, 
options=opts) - - # Train - losses = [] - learning_rates = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0]) - - # Check output - _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=rtol) - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "loss_scaler, expected_losses", - [ - ( - None, - [ - 11.041126, - 10.986309, - 11.101673, - 11.013394, - 11.037781, - 11.041253, - 10.957072, - 11.069506, - 11.040807, - 11.164349, - ], - ), - ( - amp.DynamicLossScaler(), - [ - 11.041126, - 10.986309, - 11.101673, - 11.013394, - 11.037781, - 11.041253, - 10.957072, - 11.069506, - 11.040807, - 11.164349, - ], - ), - ( - CustomLossScaler(), - [ - 11.041126, - 10.986309, - 11.101645, - 11.013412, - 11.037757, - 11.041273, - 10.957077, - 11.069525, - 11.040765, - 11.164298, - ], - ), - ], -) -def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses): - # Common setup - total_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "gradient_accumulation_steps, expected_losses", - [ - ( - 1, - [ - 11.041123, - 10.986166, - 11.101636, - 11.013366, - 11.03775, - 11.041175, - 10.957118, - 11.069563, - 11.040824, - 11.16437, - ], - ), - ( - 4, - [ - 11.041123, - 10.982856, - 11.105512, - 11.006721, - 11.03358, - 11.05058, - 10.955864, - 11.059035, - 11.037753, - 11.162649, - ], - ), - ( - 7, - [ - 11.041123, - 10.982856, - 11.105512, - 11.006721, - 11.036314, - 11.055109, - 10.960751, - 11.05809, - 11.038856, - 11.159635, - ], - ), - ], -) -def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses): - # Common setup - total_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -def testToyBertCheckpointBasic(): - # Common setup - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = 
optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({"debug": {"deterministic_compute": True}}) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - sd = trainer.state_dict() - - ## All initializers must be present in the state_dict - ## when the specified model for ORTTRainer is an ONNX model - for param in trainer._onnx_model.graph.initializer: - assert param.name in sd["model"]["full_precision"] - - ## Modify one of the state values and load into ORTTrainer - sd["model"]["full_precision"]["bert.encoder.layer.0.attention.output.LayerNorm.weight"] += 10 - trainer.load_state_dict(sd) - - ## Save a checkpoint - ckpt_dir = "testdata" - trainer.save_checkpoint(os.path.join(ckpt_dir, "bert_toy_save_test.ortcp")) - del trainer - del model - - # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer - model2 = load_bert_onnx_model() - model_desc2 = bert_model_description() - trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts) - trainer2.load_checkpoint(os.path.join(ckpt_dir, "bert_toy_save_test.ortcp")) - loaded_sd = trainer2.state_dict() - - # Assert whether original state and the one loaded from checkpoint matches - _test_commons.assert_all_states_close_ort(sd, loaded_sd) - - -def testToyBertCheckpointFrozenWeights(): - # Common setup - seed = 1 - total_steps = 10 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "utils": {"frozen_weights": ["bert.encoder.layer.0.attention.self.value.weight"]}, - } - ) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - optim_config = optim.LambConfig() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train for a few steps - for _i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, seed) - _ = trainer.train_step(*sample_input) - sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1) - # Evaluate once to get a base loss - loss = trainer.eval_step(*sample_input) - # Save checkpoint - state_dict = trainer.state_dict() - - # Load previous state into another instance of ORTTrainer - model2 = load_bert_onnx_model() - model_desc2 = bert_model_description() - optim_config2 = optim.LambConfig() - trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts) - trainer2.load_state_dict(state_dict) - # Evaluate once to get a base loss - ckpt_loss = trainer2.eval_step(*sample_input) - - # Must match as both trainers have the same dict state - assert_allclose(loss.cpu(), ckpt_loss.cpu()) - loaded_state_dict = trainer2.state_dict() - _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict) - - -@pytest.mark.parametrize( - "optimizer, mixedprecision_enabled", - [ - (optim.LambConfig(), False), - (optim.AdamConfig(), False), - (optim.LambConfig(), True), - (optim.AdamConfig(), True), - ], -) -def testToyBertLoadOptimState(optimizer, mixedprecision_enabled): - # Common setup - device = "cuda" - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optimizer - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": {"id": device}, - "mixed_precision": { - "enabled": 
mixedprecision_enabled, - }, - "distributed": {"allreduce_post_accumulation": True}, - } - ) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - dummy_init_state = _test_commons.generate_dummy_optim_state(model, optimizer) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - trainer.load_state_dict(dummy_init_state) - - # Expected values - input_ids = torch.tensor( - [ - [26598], - [21379], - [19922], - [5219], - [5644], - [20559], - [23777], - [25672], - [22969], - [16824], - [16822], - [635], - [27399], - [20647], - [18519], - [15546], - ], - device=device, - ) - segment_ids = torch.tensor( - [[0], [1], [0], [1], [0], [0], [1], [0], [0], [1], [1], [0], [0], [1], [1], [1]], device=device - ) - input_mask = torch.tensor( - [[0], [0], [0], [0], [1], [1], [1], [0], [1], [1], [0], [0], [0], [1], [0], [0]], device=device - ) - masked_lm_labels = torch.tensor( - [ - [25496], - [16184], - [11005], - [16228], - [14884], - [21660], - [8678], - [23083], - [4027], - [8397], - [11921], - [1333], - [26482], - [1666], - [17925], - [27978], - ], - device=device, - ) - next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device) - - # Actual values - _ = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels) - - actual_state_dict = trainer.state_dict() - del actual_state_dict["model"] - _test_commons.assert_all_states_close_ort(actual_state_dict, dummy_init_state) - - -@pytest.mark.parametrize( - "model_params", - [ - (["bert.embeddings.LayerNorm.bias"]), - ( - [ - "bert.embeddings.LayerNorm.bias", - "bert.embeddings.LayerNorm.weight", - "bert.encoder.layer.0.attention.output.LayerNorm.bias", - ] - ), - ], -) -def testORTTrainerFrozenWeights(model_params): - device = "cuda" - total_steps = 10 - seed = 1 - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - # Setup ORTTrainer WITHOUT frozen weights - opts_dict = { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - opts = orttrainer.ORTTrainerOptions(opts_dict) - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - trainer.train_step(*sample_input) - - # All model_params must be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert all([param in session_state for param in model_params]) - - # Setup ORTTrainer WITH frozen weights - opts_dict.update({"utils": {"frozen_weights": model_params}}) - opts = orttrainer.ORTTrainerOptions(opts_dict) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - trainer.train_step(*sample_input) - - # All model_params CANNOT be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert not any([param in session_state for param in model_params]) - - -def testToyBERTSaveAsONNX(): - device = "cuda" - onnx_file_name = "_____temp_toy_bert_onnx_model.onnx" - if os.path.exists(onnx_file_name): - os.remove(onnx_file_name) - assert not os.path.exists(onnx_file_name) - - # Load 
trainer - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - trainer.save_as_onnx(onnx_file_name) - assert os.path.exists(onnx_file_name) - - with open(onnx_file_name, "rb") as f: - bin_str = f.read() - reload_onnx_model = onnx.load_model_from_string(bin_str) - os.remove(onnx_file_name) - - # Create a new trainer from persisted ONNX model and compare with original ONNX model - trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config, options=opts) - assert trainer_from_onnx._onnx_model is not None - assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model) - for initializer, loaded_initializer in zip( - trainer._onnx_model.graph.initializer, trainer_from_onnx._onnx_model.graph.initializer - ): - assert initializer.name == loaded_initializer.name - assert onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph( - trainer._onnx_model.graph - ) - _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx) - - -############################################################################### -# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ -############################################################################### -@pytest.mark.parametrize( - "optimizer_config", - [ - (optim.AdamConfig), - # (optim.LambConfig), # TODO: re-enable after nondeterminism on backend is fixed - (optim.SGDConfig), - ], -) -def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config): - # Common setup - train_steps = 512 - - device = "cuda" - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - optim_config = optimizer_config(lr=0.01) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - if optimizer_config == optim.AdamConfig: - legacy_optimizer = "AdamOptimizer" - elif optimizer_config == optim.LambConfig: - legacy_optimizer = "LambOptimizer" - elif optimizer_config == optim.SGDConfig: - legacy_optimizer = "SGDOptimizer" - else: - raise RuntimeError("Invalid optimizer_config") - - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - legacy_optimizer, - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - ) - legacy_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True) - - 
-@pytest.mark.parametrize( - "initial_lr, lr_scheduler, legacy_lr_scheduler", - [ - (1.0, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (0.5, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (1.0, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler), - (1.0, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler), - (1.0, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler), - ], -) -def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, legacy_lr_scheduler): - ############################################################################ - # These tests require hard-coded values for 'total_steps' and 'initial_lr' # - ############################################################################ - - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - warmup = 0.05 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - - # Setup both Experimental and Legacy LR Schedulers before the experimental loop - if ( - legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler - or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler - ): - legacy_lr_scheduler = partial( - legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup - ) - elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler: - legacy_lr_scheduler = partial( - legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup, cycles=cycles - ) - elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler: - legacy_lr_scheduler = partial( - legacy_lr_scheduler, - initial_lr=initial_lr, - total_steps=total_steps, - warmup=warmup, - power=power, - lr_end=lr_end, - ) - else: - raise RuntimeError("Invalid legacy_lr_scheduler") - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optim.AdamConfig(lr=initial_lr) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "lr_scheduler": lr_scheduler, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(i)) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(initial_lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - 
learning_rate_description, - device, - _use_deterministic_compute=True, - get_lr_this_step=legacy_lr_scheduler, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize( - "loss_scaler, legacy_loss_scaler", - [ - (None, Legacy_LossScaler("ort_test_input_loss_scaler", True)), - (amp.DynamicLossScaler(), Legacy_LossScaler("ort_test_input_loss_scaler", True)), - (CustomLossScaler(), LegacyCustomLossScaler()), - ], -) -def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig(lr=0.001) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - use_mixed_precision=True, - loss_scaler=legacy_loss_scaler, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize("gradient_accumulation_steps", [(1), (4), (7)]) -def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation_steps): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - loss = trainer.train_step(*sample_input) - experimental_losses.append(loss.cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = 
legacy_model_params(optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - gradient_accumulation_steps=gradient_accumulation_steps, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize( - "params, legacy_optim_map", - [ - # Change the hyper parameters for all parameters - ([], legacy_optim_params_a), - # Change the hyperparameters for a subset of hardcoded parameters - ( - [ - { - "params": ["bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.weight"], - "alpha": 0.9, - "beta": 0.999, - "lambda_coef": 0.0, - "epsilon": 1e-6, - "do_bias_correction": False, - } - ], - legacy_optim_params_b, - ), - # Change the hyperparameters for a generated set of paramers - (optimizer_parameters(load_bert_onnx_model()), legacy_optim_params_c), - ], -) -def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL API - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.AdamConfig( - params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False - ) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr) - - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - legacy_optim_map, - learning_rate_description, - device, - _use_deterministic_compute=True, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - legacy_sample_input = [*sample_input, learning_rate] - legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py deleted file mode 100644 index d366f2cb26557..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py +++ /dev/null @@ -1,722 +0,0 @@ -from unittest.mock import Mock, patch - -import numpy as np -import onnx -import pytest -import torch -from _test_commons import _load_pytorch_transformer_model - -from onnxruntime.training import _checkpoint_storage, amp, checkpoint, optim, orttrainer # noqa: F401 - -# Helper functions - - -def 
_create_trainer(zero_enabled=False): - """Cerates a simple ORTTrainer for ORTTrainer functional tests""" - - device = "cuda" - optim_config = optim.LambConfig(lr=0.1) - opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} - if zero_enabled: - opts["distributed"] = { - "world_rank": 0, - "world_size": 1, - "horizontal_parallel_size": 1, - "data_parallel_size": 1, - "allreduce_post_accumulation": True, - "deepspeed_zero_optimization": {"stage": 1}, - } - model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer( - model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(opts) - ) - - return trainer - - -class _training_session_mock: # noqa: N801 - """Mock object for the ORTTrainer _training_session member""" - - def __init__(self, model_states, optimizer_states, partition_info): - self.model_states = model_states - self.optimizer_states = optimizer_states - self.partition_info = partition_info - - def get_model_state(self, include_mixed_precision_weights=False): - return self.model_states - - def get_optimizer_state(self): - return self.optimizer_states - - def get_partition_info_map(self): - return self.partition_info - - -def _get_load_state_dict_strict_error_arguments(): - """Return a list of tuples that can be used as parameters for test_load_state_dict_errors_when_model_key_missing - - Construct a list of tuples (training_session_state_dict, input_state_dict, error_arguments) - The load_state_dict function will compare the two state dicts (training_session_state_dict, input_state_dict) and - throw a runtime error with the missing/unexpected keys. The error arguments capture these missing/unexpected keys. - """ - - training_session_state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5)}, - }, - } - - # input state dictionaries - precision_key_missing = {"model": {}, "optimizer": {}} - precision_key_unexpected = {"model": {"full_precision": {}, "mixed_precision": {}}, "optimizer": {}} - model_state_key_missing = {"model": {"full_precision": {}}, "optimizer": {}} - model_state_key_unexpected = {"model": {"full_precision": {"a": 2, "b": 3, "c": 4}}, "optimizer": {}} - optimizer_model_state_key_missing = {"model": {"full_precision": {"a": 2, "b": 3}}, "optimizer": {}} - optimizer_model_state_key_unexpected = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": {"a": {}, "shared_optimizer_state": {}, "b": {}}, - } - optimizer_state_key_missing = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": {"a": {}, "shared_optimizer_state": {"step": np.arange(5)}}, - } - optimizer_state_key_unexpected = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5), "another_step": np.arange(1)}, - }, - } - - input_arguments = [ - (training_session_state_dict, precision_key_missing, ["full_precision"]), - (training_session_state_dict, precision_key_unexpected, ["mixed_precision"]), - (training_session_state_dict, model_state_key_missing, ["a", "b"]), - (training_session_state_dict, model_state_key_unexpected, ["c"]), - (training_session_state_dict, optimizer_model_state_key_missing, ["a", "shared_optimizer_state"]), - (training_session_state_dict, 
optimizer_model_state_key_unexpected, ["b"]), - (training_session_state_dict, optimizer_state_key_missing, ["Moment_1", "Moment_2"]), - (training_session_state_dict, optimizer_state_key_unexpected, ["another_step"]), - ] - - return input_arguments - - -# Tests - - -def test_empty_state_dict_when_training_session_uninitialized(): - trainer = _create_trainer() - with pytest.warns(UserWarning) as user_warning: - state_dict = trainer.state_dict() - - assert len(state_dict.keys()) == 0 - assert ( - user_warning[0].message.args[0] == "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict()." - ) - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_model_states(onnx_model_mock): - trainer = _create_trainer() - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["model"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_model_states(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_model_states_pytorch_format(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict(pytorch_format=True) - assert torch.all(torch.eq(state_dict["a"], torch.tensor(np.arange(5)))) - assert torch.all(torch.eq(state_dict["b"], torch.tensor(np.arange(7)))) - - -@patch("onnx.ModelProto") -def test_onnx_graph_provides_frozen_model_states(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - trainer.options.utils.frozen_weights = ["a_frozen_weight", "a_float16_weight"] - trainer._onnx_model.graph.initializer = [ - onnx.numpy_helper.from_array(np.array([1, 2, 3], dtype=np.float32), "a_frozen_weight"), - onnx.numpy_helper.from_array(np.array([4, 5, 6], dtype=np.float32), "a_non_fronzen_weight"), - onnx.numpy_helper.from_array(np.array([7, 8, 9], dtype=np.float16), "a_float16_weight"), - ] - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - assert (state_dict["model"]["full_precision"]["a_frozen_weight"] == np.array([1, 2, 3], dtype=np.float32)).all() - assert "a_non_fronzen_weight" not in state_dict["model"]["full_precision"] - assert (state_dict["model"]["full_precision"]["a_float16_weight"] == np.array([7, 8, 9], dtype=np.float32)).all() - - -@patch("onnx.ModelProto") -def 
test_training_session_provides_empty_optimizer_states(onnx_model_mock): - trainer = _create_trainer() - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["optimizer"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_optimizer_states(onnx_model_mock): - trainer = _create_trainer() - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - training_session_mock = _training_session_mock({}, optimizer_states, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all() - assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_optimizer_states_pytorch_format(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - training_session_mock = _training_session_mock(model_states, optimizer_states, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict(pytorch_format=True) - assert "optimizer" not in state_dict - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_partition_info_map(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["partition_info"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_partition_info_map(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - partition_info = {"a": {"original_dim": [1, 2, 3]}} - training_session_mock = _training_session_mock({}, {}, partition_info) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3] - - -@patch("onnx.ModelProto") -def test_training_session_provides_all_states(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - partition_info = {"a": {"original_dim": [1, 2, 3]}} - training_session_mock = _training_session_mock(model_states, optimizer_states, partition_info) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all() - assert 
(state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all() - assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all() - assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3] - - -def test_load_state_dict_holds_when_training_session_not_initialized(): - trainer = _create_trainer() - state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5)}, - }, - } - assert not trainer._load_state_dict - state_dict = trainer.load_state_dict(state_dict) - assert trainer._load_state_dict - - -@pytest.mark.parametrize( - "state_dict, input_state_dict, error_key", - [ - ( - {"model": {}, "optimizer": {}}, - {"model": {}, "optimizer": {}, "trainer_options": {"optimizer_name": "LambOptimizer"}}, - "train_step_info", - ), - ( - {"optimizer": {}, "train_step_info": {"optimization_step": 0, "step": 0}}, - { - "optimizer": {}, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - "train_step_info": {"optimization_step": 0, "step": 0}, - }, - "model", - ), - ( - {"model": {}, "train_step_info": {"optimization_step": 0, "step": 0}}, - { - "model": {}, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - "train_step_info": {"optimization_step": 0, "step": 0}, - }, - "optimizer", - ), - ], -) -def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, input_state_dict, error_key): - trainer = _create_trainer() - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=state_dict) - trainer._update_onnx_model_initializers = Mock() - trainer._init_session = Mock() - with patch("onnx.ModelProto") as onnx_model_mock: - trainer._onnx_model = onnx_model_mock() - trainer._onnx_model.graph.initializer = [] - with pytest.warns(UserWarning) as user_warning: - trainer.load_state_dict(input_state_dict) - - assert user_warning[0].message.args[0] == f"Missing key: {error_key} in state_dict" - - -@pytest.mark.parametrize("state_dict, input_state_dict, error_keys", _get_load_state_dict_strict_error_arguments()) -def test_load_state_dict_errors_when_state_dict_mismatch(state_dict, input_state_dict, error_keys): - trainer = _create_trainer() - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=state_dict) - with pytest.raises(RuntimeError) as runtime_error: - trainer.load_state_dict(input_state_dict) - - assert any(key in str(runtime_error.value) for key in error_keys) - - -@patch("onnx.ModelProto") -def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_mock): - trainer = _create_trainer() - training_session_state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - }, - } - - input_state_dict = { - "model": {"full_precision": {"a": np.array([1, 2]), "b": np.array([3, 4])}}, - "optimizer": { - "a": {"Moment_1": np.array([5, 6]), "Moment_2": np.array([7, 8])}, - "shared_optimizer_state": {"step": np.array([9])}, - }, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - } - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=training_session_state_dict) - trainer._onnx_model = onnx_model_mock() - trainer._onnx_model.graph.initializer = [ - 
onnx.numpy_helper.from_array(np.arange(20, dtype=np.float32), "a"), - onnx.numpy_helper.from_array(np.arange(25, dtype=np.float32), "b"), - ] - trainer._update_onnx_model_initializers = Mock() - trainer._init_session = Mock() - - trainer.load_state_dict(input_state_dict) - - loaded_initializers, _ = trainer._update_onnx_model_initializers.call_args - state_dict_to_load, _ = trainer._init_session.call_args - - assert "a" in loaded_initializers[0] - assert (loaded_initializers[0]["a"] == np.array([1, 2])).all() - assert "b" in loaded_initializers[0] - assert (loaded_initializers[0]["b"] == np.array([3, 4])).all() - - assert (state_dict_to_load[0]["a"]["Moment_1"] == np.array([5, 6])).all() - assert (state_dict_to_load[0]["a"]["Moment_2"] == np.array([7, 8])).all() - assert (state_dict_to_load[0]["shared_optimizer_state"]["step"] == np.array([9])).all() - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_calls_checkpoint_storage_save(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc") - - save_args, _ = save_mock.call_args - assert "model" in save_args[0] - assert not bool(save_args[0]["model"]) - assert "optimizer" in save_args[0] - assert not bool(save_args[0]["optimizer"]) - assert save_args[1] == "abc" - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_exclude_optimizer_states(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc", include_optimizer_states=False) - - save_args, _ = save_mock.call_args - assert "model" in save_args[0] - assert not bool(save_args[0]["model"]) - assert "optimizer" not in save_args[0] - assert save_args[1] == "abc" - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_user_dict(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc", user_dict={"abc": np.arange(4)}) - - save_args, _ = save_mock.call_args - assert "user_dict" in save_args[0] - assert save_args[0]["user_dict"] == _checkpoint_storage.to_serialized_hex({"abc": np.arange(4)}) - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -def test_load_checkpoint(aggregate_checkpoints_mock, load_mock): - trainer = _create_trainer() - trainer_options = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - } - state_dict = { - "model": {}, - "optimizer": {}, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - }, - } - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options, state_dict] - trainer.load_checkpoint("abc") - - args_list = load_mock.call_args_list - load_args, load_kwargs = args_list[0] - assert load_args[0] == "abc" - assert load_kwargs["key"] == "trainer_options" - load_args, load_kwargs = args_list[1] - assert load_args[0] == "abc" - assert "key" not in load_kwargs - assert not aggregate_checkpoints_mock.called - - 
-@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -@pytest.mark.parametrize( - "trainer_options", - [ - { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(4), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(4), - "zero_stage": np.int64(1), - }, - { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(1), - }, - { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(1), - }, - ], -) -def test_load_checkpoint_aggregation_required_zero_enabled(aggregate_checkpoints_mock, load_mock, trainer_options): - trainer = _create_trainer() - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options] - trainer.load_checkpoint("abc") - - args_list = load_mock.call_args_list - load_args, load_kwargs = args_list[0] - assert load_args[0] == "abc" - assert load_kwargs["key"] == "trainer_options" - assert aggregate_checkpoints_mock.called - call_args, _ = aggregate_checkpoints_mock.call_args - assert call_args[0] == tuple(["abc"]) - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock): - trainer = _create_trainer() - trainer_options = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - } - state_dict = { - "model": {}, - "optimizer": {}, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - }, - "user_dict": _checkpoint_storage.to_serialized_hex({"array": torch.tensor(np.arange(5))}), - } - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options, state_dict] - user_dict = trainer.load_checkpoint("abc") - - assert torch.all(torch.eq(user_dict["array"], torch.tensor(np.arange(5)))) - - -@patch("onnxruntime.training._checkpoint_storage.load") -def test_checkpoint_aggregation(load_mock): - trainer_options1 = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - trainer_options2 = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(1), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - - state_dict1 = { - "model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "optimizer_sharded": { - "Moment_1": np.array([9, 8, 7]), - "Moment_2": np.array([99, 88, 77]), - "Step": np.array([5]), - }, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(False), - 
"world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}}, - } - - state_dict2 = { - "model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "optimizer_sharded": { - "Moment_1": np.array([6, 5, 4]), - "Moment_2": np.array([66, 55, 44]), - "Step": np.array([5]), - }, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(1), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}}, - } - - load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] - state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False) - - assert (state_dict["model"]["full_precision"]["optimizer_sharded"] == np.array([1, 2, 3])).all() - assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Step"] == np.array([5])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all() - - assert state_dict["trainer_options"]["mixed_precision"] is False - assert state_dict["trainer_options"]["world_rank"] == 0 - assert state_dict["trainer_options"]["world_size"] == 1 - assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1 - assert state_dict["trainer_options"]["data_parallel_size"] == 1 - assert state_dict["trainer_options"]["zero_stage"] == 0 - assert state_dict["trainer_options"]["optimizer_name"] == b"Adam" - - -@patch("onnxruntime.training._checkpoint_storage.load") -def test_checkpoint_aggregation_mixed_precision(load_mock): - trainer_options1 = { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - trainer_options2 = { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(1), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - - state_dict1 = { - "model": {"full_precision": {"sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "sharded": {"Moment_1": np.array([9, 8, 7]), "Moment_2": np.array([99, 88, 77]), "Step": np.array([5])}, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": 
np.bool_(True),
-            "world_rank": np.int64(0),
-            "world_size": np.int64(1),
-            "horizontal_parallel_size": np.int64(1),
-            "data_parallel_size": np.int64(1),
-            "zero_stage": np.int64(0),
-            "optimizer_name": b"Adam",
-        },
-        "partition_info": {"sharded": {"original_dim": np.array([2, 3])}},
-    }
-
-    state_dict2 = {
-        "model": {"full_precision": {"sharded": np.array([4, 5, 6]), "non_sharded": np.array([11, 22, 33])}},
-        "optimizer": {
-            "sharded": {"Moment_1": np.array([6, 5, 4]), "Moment_2": np.array([66, 55, 44]), "Step": np.array([5])},
-            "non_sharded": {
-                "Moment_1": np.array([666, 555, 444]),
-                "Moment_2": np.array([6666, 5555, 4444]),
-                "Step": np.array([55]),
-            },
-        },
-        "trainer_options": {
-            "mixed_precision": np.bool_(True),
-            "world_rank": np.int64(1),
-            "world_size": np.int64(1),
-            "horizontal_parallel_size": np.int64(1),
-            "data_parallel_size": np.int64(1),
-            "zero_stage": np.int64(0),
-            "optimizer_name": b"Adam",
-        },
-        "partition_info": {"sharded": {"original_dim": np.array([2, 3])}},
-    }
-
-    load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2]
-    state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False)
-
-    assert (state_dict["model"]["full_precision"]["sharded"] == np.array([[1, 2, 3], [4, 5, 6]])).all()
-    assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all()
-    assert (state_dict["optimizer"]["sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all()
-    assert (state_dict["optimizer"]["sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all()
-    assert (state_dict["optimizer"]["sharded"]["Step"] == np.array([5])).all()
-    assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all()
-    assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all()
-    assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all()
-
-    assert state_dict["trainer_options"]["mixed_precision"] is True
-    assert state_dict["trainer_options"]["world_rank"] == 0
-    assert state_dict["trainer_options"]["world_size"] == 1
-    assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1
-    assert state_dict["trainer_options"]["data_parallel_size"] == 1
-    assert state_dict["trainer_options"]["zero_stage"] == 0
-    assert state_dict["trainer_options"]["optimizer_name"] == b"Adam"
diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
deleted file mode 100644
index fa13625f0ddac..0000000000000
--- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
+++ /dev/null
@@ -1,2460 +0,0 @@
-import inspect
-import os
-import tempfile
-from functools import partial
-
-import _test_commons
-import _test_helpers
-import onnx
-import pytest
-import torch
-import torch.nn.functional as F
-from numpy.testing import assert_allclose
-from packaging.version import Version as StrictVersion
-
-from onnxruntime import SessionOptions, set_seed
-from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler
-from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer
-from onnxruntime.training import PropagateCastOpsStrategy, TrainStepInfo, _utils, amp
-from onnxruntime.training import model_desc_validation as md_val
-from onnxruntime.training import optim, orttrainer, orttrainer_options
-
-############################################################################### -# Testing starts here ######################################################### -############################################################################### - -pytorch_110 = StrictVersion(".".join(torch.__version__.split(".")[:2])) >= StrictVersion("1.10.0") - - -def get_model_opset(model_onnx): - for op in model_onnx.opset_import: - if op.domain == "": - return op.version - return None - - -@pytest.mark.parametrize( - "test_input", - [({}), ({"batch": {}, "device": {}, "distributed": {}, "mixed_precision": {}, "utils": {}, "_internal_use": {}})], -) -def testORTTrainerOptionsDefaultValues(test_input): - """Test different ways of using default values for incomplete input""" - - expected_values = { - "batch": {"gradient_accumulation_steps": 1}, - "device": {"id": "cuda", "mem_limit": 0}, - "distributed": { - "world_rank": 0, - "world_size": 1, - "local_rank": 0, - "data_parallel_size": 1, - "horizontal_parallel_size": 1, - "pipeline_parallel": { - "pipeline_parallel_size": 1, - "num_pipeline_micro_batches": 1, - "pipeline_cut_info_string": "", - "sliced_schema": {}, - "sliced_axes": {}, - "sliced_tensor_names": [], - }, - "allreduce_post_accumulation": False, - "deepspeed_zero_optimization": { - "stage": 0, - }, - "enable_adasum": False, - }, - "lr_scheduler": None, - "mixed_precision": {"enabled": False, "loss_scaler": None}, - "graph_transformer": { - "attn_dropout_recompute": False, - "gelu_recompute": False, - "transformer_layer_recompute": False, - "number_recompute_layers": 0, - "propagate_cast_ops_config": {"strategy": PropagateCastOpsStrategy.FLOOD_FILL, "level": 1, "allow": []}, - }, - "utils": { - "frozen_weights": [], - "grad_norm_clip": True, - "memory_efficient_gradient": False, - "run_symbolic_shape_infer": False, - }, - "debug": { - "deterministic_compute": False, - "check_model_export": False, - "graph_save_paths": { - "model_after_graph_transforms_path": "", - "model_with_gradient_graph_path": "", - "model_with_training_graph_path": "", - "model_with_training_graph_after_optimization_path": "", - }, - }, - "_internal_use": { - "enable_internal_postprocess": True, - "extra_postprocess": None, - "onnx_opset_version": 14, - "enable_onnx_contrib_ops": True, - }, - "provider_options": {}, - "session_options": None, - } - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values._validated_opts == expected_values - - -@pytest.mark.parametrize( - "input,error_msg", - [ - ( - {"mixed_precision": {"enabled": 1}}, - "Invalid options: {'mixed_precision': [{'enabled': ['must be of boolean type']}]}", - ) - ], -) -def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema(input, error_msg): - """Test an invalid input based on schema validation error message""" - - with pytest.raises(ValueError) as e: - orttrainer_options.ORTTrainerOptions(input) - assert str(e.value) == error_msg - - -@pytest.mark.parametrize( - "input_dict,input_dtype,output_dtype", - [ - ( - {"inputs": [("in0", [])], "outputs": [("out0", []), ("out1", [])]}, - (torch.int,), - ( - torch.float, - torch.int32, - ), - ), - ({"inputs": [("in0", ["batch", 2, 3])], "outputs": [("out0", [], True)]}, (torch.int8,), (torch.int16,)), - ( - { - "inputs": [ - ("in0", []), - ("in1", [1]), - ("in2", [1, 2]), - ("in3", [1000, "dyn_ax1"]), - ("in4", ["dyn_ax1", "dyn_ax2", "dyn_ax3"]), - ], - "outputs": [("out0", [], True), ("out1", [1], False), ("out2", [1, "dyn_ax1", 3])], - }, - ( - torch.float, - torch.uint8, - 
torch.bool, - torch.double, - torch.half, - ), - (torch.float, torch.float, torch.int64), - ), - ], -) -def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype): - r"""Test different ways of using default values for incomplete input""" - - model_description = md_val._ORTTrainerModelDesc(input_dict) - - # Validating hard-coded learning rate description - assert model_description.learning_rate.name == md_val.LEARNING_RATE_IO_DESCRIPTION_NAME - assert model_description.learning_rate.shape == [1] - assert model_description.learning_rate.dtype == torch.float32 - - # Validating model description from user - for idx, i_desc in enumerate(model_description.inputs): - assert isinstance(i_desc, model_description._InputDescription) - assert len(i_desc) == 2 - assert input_dict["inputs"][idx][0] == i_desc.name - assert input_dict["inputs"][idx][1] == i_desc.shape - for idx, o_desc in enumerate(model_description.outputs): - assert isinstance(o_desc, model_description._OutputDescription) - assert len(o_desc) == 3 - assert input_dict["outputs"][idx][0] == o_desc.name - assert input_dict["outputs"][idx][1] == o_desc.shape - is_loss = input_dict["outputs"][idx][2] if len(input_dict["outputs"][idx]) == 3 else False - assert is_loss == o_desc.is_loss - - # Set all_finite name and check its description - model_description.all_finite = md_val.ALL_FINITE_IO_DESCRIPTION_NAME - assert model_description.all_finite.name == md_val.ALL_FINITE_IO_DESCRIPTION_NAME - assert model_description.all_finite.shape == [1] - assert model_description.all_finite.dtype == torch.bool - - # Set loss_scale_input and check its description - model_description.loss_scale_input = md_val.LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME - assert model_description.loss_scale_input.name == md_val.LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME - assert model_description.loss_scale_input.shape == [] - assert model_description.loss_scale_input.dtype == torch.float32 - - # Append type to inputs/outputs tuples - for idx, i_desc in enumerate(model_description.inputs): # noqa: B007 - model_description.add_type_to_input_description(idx, input_dtype[idx]) - for idx, o_desc in enumerate(model_description.outputs): # noqa: B007 - model_description.add_type_to_output_description(idx, output_dtype[idx]) - - # Verify inputs/outputs tuples are replaced by the typed counterparts - for idx, i_desc in enumerate(model_description.inputs): - assert isinstance(i_desc, model_description._InputDescriptionTyped) - assert input_dtype[idx] == i_desc.dtype - for idx, o_desc in enumerate(model_description.outputs): - assert isinstance(o_desc, model_description._OutputDescriptionTyped) - assert output_dtype[idx] == o_desc.dtype - - -@pytest.mark.parametrize( - "input_dict,error_msg", - [ - ( - {"inputs": [(True, [])], "outputs": [(True, [])]}, - "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], " - "'outputs': [{0: ['the first element of the tuple (aka name) must be a string']}]}", - ), - ( - {"inputs": [("in1", None)], "outputs": [("out1", None)]}, - "Invalid model_desc: {'inputs': [{0: ['the second element of the tuple (aka shape) must be a list']}], " - "'outputs': [{0: ['the second element of the tuple (aka shape) must be a list']}]}", - ), - ( - {"inputs": [("in1", [])], "outputs": [("out1", [], None)]}, - "Invalid model_desc: {'outputs': [{0: ['the third element of the tuple (aka is_loss) must be a boolean']}]}", - ), - ( - {"inputs": [("in1", [True])], "outputs": [("out1", [True])]}, - "Invalid model_desc: 
{'inputs': [{0: ['each shape must be either a string or integer']}], " - "'outputs': [{0: ['each shape must be either a string or integer']}]}", - ), - ( - {"inputs": [("in1", [])], "outputs": [("out1", [], True), ("out2", [], True)]}, - "Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}", - ), - ( - {"inputz": [("in1", [])], "outputs": [("out1", [], True)]}, - "Invalid model_desc: {'inputs': ['required field'], 'inputz': ['unknown field']}", - ), - ( - {"inputs": [("in1", [])], "outputz": [("out1", [], True)]}, - "Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}", - ), - ], -) -def testORTTrainerModelDescInvalidSchemas(input_dict, error_msg): - r"""Test different ways of using default values for incomplete input""" - with pytest.raises(ValueError) as e: - md_val._ORTTrainerModelDesc(input_dict) - assert str(e.value) == error_msg - - -def testDynamicLossScaler(): - rtol = 1e-7 - default_scaler = amp.loss_scaler.DynamicLossScaler() - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(default_scaler.loss_scale, float(1 << 16), rtol=rtol, err_msg="loss scale mismatch") - assert default_scaler.up_scale_window == 2000 - assert_allclose(default_scaler.min_loss_scale, 1.0, rtol=rtol, err_msg="min loss scale mismatch") - assert_allclose(default_scaler.max_loss_scale, float(1 << 24), rtol=rtol, err_msg="max loss scale mismatch") - - # Performing 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True) - loss_scale = float(1 << 16) - for cycles in range(1, 10): - # 1999 updates without overflow produces 1999 stable steps - for i in range(1, 2000): - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == i - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg=f"loss scale mismatch at update {i}") - - # 2000th update without overflow doubles the loss and zero stable steps until max_loss_scale is reached - new_loss_scale = default_scaler.update(train_step_info) - if cycles <= 8: - loss_scale *= 2 - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # After 8 cycles, loss scale should be float(1 << 16)*(2**8) - assert_allclose(new_loss_scale, float(1 << 16) * (2**8), rtol=rtol, err_msg="loss scale mismatch") - - # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on - loss_scale = float(1 << 16) * (2**8) - for count in range(1, 2050): - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == (count % 2000) - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # Setting train_step_info.all_finite = False to test down scaling - train_step_info.all_finite = False - - # Performing 24 updates to half the loss scale each time - loss_scale = float(1 << 16) * (2**8) - for count in range(1, 25): # noqa: B007 - new_loss_scale = default_scaler.update(train_step_info) - loss_scale /= 2 - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # After 24 updates with gradient overflow, loss scale is 1.0 - assert_allclose(new_loss_scale, 1.0, rtol=rtol, err_msg="loss scale mismatch") - - # After 25 updates, min_loss_scale is reached and loss scale is not halfed from that point on - for count in range(1, 5): # noqa: B007 - 
new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - -def testDynamicLossScalerCustomValues(): - rtol = 1e-7 - scaler = amp.loss_scaler.DynamicLossScaler( - automatic_update=False, loss_scale=3, up_scale_window=7, min_loss_scale=5, max_loss_scale=10 - ) - assert scaler.automatic_update is False - assert_allclose(scaler.loss_scale, 3, rtol=rtol, err_msg="loss scale mismatch") - assert_allclose(scaler.min_loss_scale, 5, rtol=rtol, err_msg="min loss scale mismatch") - assert_allclose(scaler.max_loss_scale, 10, rtol=rtol, err_msg="max loss scale mismatch") - assert scaler.up_scale_window == 7 - - -def testTrainStepInfo(): - """Test valid initializations of TrainStepInfo""" - - optimizer_config = optim.LambConfig() - fetches = ["out1", "out2"] - step_info = orttrainer.TrainStepInfo( - optimizer_config=optimizer_config, all_finite=False, fetches=fetches, optimization_step=123, step=456 - ) - assert step_info.optimizer_config == optimizer_config - assert step_info.all_finite is False - assert step_info.fetches == fetches - assert step_info.optimization_step == 123 - assert step_info.step == 456 - - step_info = orttrainer.TrainStepInfo(optimizer_config) - assert step_info.optimizer_config == optimizer_config - assert step_info.all_finite is True - assert step_info.fetches == [] - assert step_info.optimization_step == 0 - assert step_info.step == 0 - - -@pytest.mark.parametrize( - "invalid_input", - [ - (-1), - ("Hello"), - ], -) -def testTrainStepInfoInvalidInput(invalid_input): - """Test invalid initialization of TrainStepInfo""" - optimizer_config = optim.LambConfig() - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, all_finite=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, fetches=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, optimization_step=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, step=invalid_input) - - -@pytest.mark.parametrize( - "optim_name,lr,alpha,default_alpha", - [ - ("AdamOptimizer", 0.1, 0.2, None), - ("LambOptimizer", 0.2, 0.3, None), - ("SGDOptimizer", 0.3, 0.4, None), - ("SGDOptimizer", 0.3, 0.4, 0.5), - ], -) -def testOptimizerConfig(optim_name, lr, alpha, default_alpha): - """Test initialization of _OptimizerConfig""" - defaults = {"lr": lr, "alpha": alpha} - params = [{"params": ["fc1.weight", "fc2.weight"]}] - if default_alpha is not None: - params[0].update({"alpha": default_alpha}) - else: - params[0].update({"alpha": alpha}) - cfg = optim.config._OptimizerConfig(name=optim_name, params=params, defaults=defaults) - - assert cfg.name == optim_name - rtol = 1e-07 - assert_allclose(defaults["lr"], cfg.lr, rtol=rtol, err_msg="lr mismatch") - - # 1:1 mapping between defaults and params's hyper parameters - for param in params: - for k in param: - if k != "params": - assert k in cfg.defaults, "hyper parameter {k} not present in one of the parameter params" - for k in cfg.defaults: - for param in cfg.params: - assert k in param, "hyper parameter {k} not present in one of the parameter params" - - -@pytest.mark.parametrize( - "optim_name,defaults,params", - [ - ("AdamOptimizer", {"lr": -1}, []), # invalid lr - ("FooOptimizer", {"lr": 0.001}, []), 
# invalid name - ("SGDOptimizer", [], []), # invalid type(defaults) - (optim.AdamConfig, {"lr": 0.003}, []), # invalid type(name) - ("AdamOptimizer", {"lr": None}, []), # missing 'lr' hyper parameter - ("SGDOptimizer", {"lr": 0.004}, {}), # invalid type(params) - # invalid type(params[i]) - ("AdamOptimizer", {"lr": 0.005, "alpha": 2}, [[]]), - # missing 'params' at 'params' - ("AdamOptimizer", {"lr": 0.005, "alpha": 2}, [{"alpha": 1}]), - # missing 'alpha' at 'defaults' - ("AdamOptimizer", {"lr": 0.005}, [{"params": "param1", "alpha": 1}]), - ], -) -def testOptimizerConfigInvalidInputs(optim_name, defaults, params): - """Test invalid initialization of _OptimizerConfig""" - - with pytest.raises(AssertionError): - optim.config._OptimizerConfig(name=optim_name, params=params, defaults=defaults) - - -def testOptimizerConfigSGD(): - """Test initialization of SGD""" - cfg = optim.SGDConfig() - assert cfg.name == "SGDOptimizer" - - rtol = 1e-07 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - - cfg = optim.SGDConfig(lr=0.002) - assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch") - - # SGD does not support params - with pytest.raises(AssertionError) as e: - params = [{"params": ["layer1.weight"], "lr": 0.1}] - optim.SGDConfig(params=params, lr=0.002) - assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert str(e.value) == "'params' must be an empty list for SGD optimizer" - - -def testOptimizerConfigAdam(): - """Test initialization of Adam""" - cfg = optim.AdamConfig() - assert cfg.name == "AdamOptimizer" - - rtol = 1e-7 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch") - assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch") - assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch") - assert_allclose(1e-8, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch") - assert_allclose(1.0, cfg.max_norm_clip, rtol=rtol, err_msg="max_norm_clip mismatch") - assert cfg.do_bias_correction is True, "lambda_coef mismatch" - assert cfg.weight_decay_mode == optim.AdamConfig.DecayMode.BEFORE_WEIGHT_UPDATE, "weight_decay_mode mismatch" - - -def testOptimizerConfigLamb(): - """Test initialization of Lamb""" - cfg = optim.LambConfig() - assert cfg.name == "LambOptimizer" - rtol = 1e-7 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch") - assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch") - assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch") - assert cfg.ratio_min == float("-inf"), "ratio_min mismatch" - assert cfg.ratio_max == float("inf"), "ratio_max mismatch" - assert_allclose(1e-6, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch") - assert_allclose(1.0, cfg.max_norm_clip, rtol=rtol, err_msg="max_norm_clip mismatch") - assert cfg.do_bias_correction is False, "do_bias_correction mismatch" - - -@pytest.mark.parametrize("optim_name", [("Adam"), ("Lamb")]) -def testOptimizerConfigParams(optim_name): - rtol = 1e-7 - params = [{"params": ["layer1.weight"], "alpha": 0.1}] - if optim_name == "Adam": - cfg = optim.AdamConfig(params=params, alpha=0.2) - elif optim_name == "Lamb": - cfg = optim.LambConfig(params=params, alpha=0.2) - else: - raise ValueError("invalid input") - assert len(cfg.params) == 1, "params should have length 1" - assert_allclose(cfg.params[0]["alpha"], 0.1, rtol=rtol, err_msg="invalid lr on 
params[0]") - - -@pytest.mark.parametrize("optim_name", [("Adam"), ("Lamb")]) -def testOptimizerConfigInvalidParams(optim_name): - # lr is not supported within params - with pytest.raises(AssertionError) as e: - params = [{"params": ["layer1.weight"], "lr": 0.1}] - if optim_name == "Adam": - optim.AdamConfig(params=params, lr=0.2) - elif optim_name == "Lamb": - optim.LambConfig(params=params, lr=0.2) - else: - raise ValueError("invalid input") - assert str(e.value) == "'lr' is not supported inside params" - - -def testLinearLRSchedulerCreation(): - total_steps = 10 - warmup = 0.05 - - lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(total_steps, warmup) - - # Initial state - assert lr_scheduler.total_steps == total_steps - assert lr_scheduler.warmup == warmup - - -@pytest.mark.parametrize( - "lr_scheduler,expected_values", - [ - (optim.lr_scheduler.ConstantWarmupLRScheduler, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0]), - ( - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.9763960957919413, - 0.9059835861602854, - 0.7956724530494887, - 0.6563036824392345, - 0.5015739416158049, - 0.34668951940611276, - 0.2068719061737831, - 0.09586187986225325, - 0.0245691111902418, - ], - ), - (optim.lr_scheduler.LinearWarmupLRScheduler, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2]), - ( - optim.lr_scheduler.PolyWarmupLRScheduler, - [ - 0.0, - 0.9509018036072144, - 0.9008016032064128, - 0.8507014028056112, - 0.8006012024048097, - 0.750501002004008, - 0.7004008016032064, - 0.6503006012024048, - 0.6002004008016032, - 0.5501002004008015, - ], - ), - ], -) -def testLRSchedulerUpdateImpl(lr_scheduler, expected_values): - # Test tolerance - rtol = 1e-03 - - # Initial state - initial_lr = 1 - total_steps = 10 - warmup = 0.5 - optimizer_config = optim.SGDConfig(lr=initial_lr) - lr_scheduler = lr_scheduler(total_steps, warmup) - - # First half is warmup - for optimization_step in range(total_steps): - # Emulate ORTTRainer.train_step() call that updates its train_step_info - train_step_info = TrainStepInfo(optimizer_config=optimizer_config, optimization_step=optimization_step) - - lr_scheduler._step(train_step_info) - lr_list = lr_scheduler.get_last_lr() - assert len(lr_list) == 1 - assert_allclose(lr_list[0], expected_values[optimization_step], rtol=rtol, err_msg="lr mismatch") - - -def testInstantiateORTTrainerOptions(): - session_options = SessionOptions() - session_options.enable_mem_pattern = False - provider_options = {"EP1": {"key": "val"}} - opts = {"session_options": session_options, "provider_options": provider_options} - opts = orttrainer.ORTTrainerOptions(opts) - assert opts.session_options.enable_mem_pattern is False - assert opts._validated_opts["provider_options"]["EP1"]["key"] == "val" - - -@pytest.mark.parametrize( - "step_fn, lr_scheduler, expected_lr_values, device", - [ - ("train_step", None, None, "cuda"), - ("eval_step", None, None, "cpu"), - ( - "train_step", - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0], - "cpu", - ), - ( - "train_step", - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.2, - 0.4, - 0.6, - 0.8, - 1.0, - 0.9045084971874737, - 0.6545084971874737, - 0.34549150281252633, - 0.09549150281252633, - ], - "cuda", - ), - ( - "train_step", - optim.lr_scheduler.LinearWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2], - "cpu", - ), - ( - "train_step", - optim.lr_scheduler.PolyWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.80000002, 0.60000004, 0.40000006000000005, 
0.20000007999999997], - "cuda", - ), - ], -) -def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device): - total_steps = 1 - initial_lr = 1.0 - rtol = 1e-3 - - # PyTorch Transformer model as example - opts = {"device": {"id": device}} - if lr_scheduler: - total_steps = 10 - opts.update({"lr_scheduler": lr_scheduler(total_steps=total_steps, warmup=0.5)}) - opts = orttrainer.ORTTrainerOptions(opts) - optim_config = optim.LambConfig(lr=initial_lr) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - - # Run a train or evaluation step - if step_fn == "eval_step": - data, targets = batcher_fn(val_data, 0) - elif step_fn == "train_step": - data, targets = batcher_fn(train_data, 0) - else: - raise ValueError("Invalid step_fn") - - # Export model to ONNX - if step_fn == "eval_step": - step_fn = trainer.eval_step - output = trainer.eval_step(data, targets) - elif step_fn == "train_step": - step_fn = trainer.train_step - for i in range(total_steps): - output = trainer.train_step(data, targets) - if lr_scheduler: - lr_list = trainer.options.lr_scheduler.get_last_lr() - assert_allclose(lr_list[0], expected_lr_values[i], rtol=rtol, err_msg="lr mismatch") - else: - raise ValueError("Invalid step_fn") - assert trainer._onnx_model is not None - - # Check output shape after train/eval step - for out, desc in zip(output, trainer.model_desc.outputs): - if trainer.loss_fn and desc.is_loss: - continue - assert list(out.size()) == desc.shape - - # Check name, shape and dtype of the first len(forward.parameters) ORT graph inputs - sig = inspect.signature(model.forward) - for i in range(len(sig.parameters.keys())): - input_name = trainer.model_desc.inputs[i][0] - input_dim = trainer.model_desc.inputs[i][1] - input_type = trainer.model_desc.inputs[i][2] - - assert trainer._onnx_model.graph.input[i].name == input_name - for dim_idx, dim in enumerate(trainer._onnx_model.graph.input[i].type.tensor_type.shape.dim): - assert input_dim[dim_idx] == dim.dim_value - assert input_type == _utils.dtype_onnx_to_torch( - trainer._onnx_model.graph.input[i].type.tensor_type.elem_type - ) - - opset = get_model_opset(trainer._onnx_model) - - # Check name, shape and dtype of the ORT graph outputs - for i in range(len(trainer.model_desc.outputs)): - output_name = trainer.model_desc.outputs[i][0] - output_dim = trainer.model_desc.outputs[i][1] - output_type = trainer.model_desc.outputs[i][3] - - assert trainer._onnx_model.graph.output[i].name == output_name - for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim): - if opset is None or opset <= 12: - assert output_dim[dim_idx] == dim.dim_value - assert output_type == _utils.dtype_onnx_to_torch( - trainer._onnx_model.graph.output[i].type.tensor_type.elem_type - ) - - # Save current model as ONNX as a file - file_name = os.path.join("_____temp_onnx_model.onnx") - trainer.save_as_onnx(file_name) - assert os.path.exists(file_name) - with open(file_name, "rb") as f: - bin_str = f.read() - reload_onnx_model = onnx.load_model_from_string(bin_str) - os.remove(file_name) - - # Create a new trainer from persisted ONNX model and compare with original ONNX model - trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config) - step_fn(data, targets) - assert trainer_from_onnx._onnx_model is not None - assert 
id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model)
-    assert trainer_from_onnx._onnx_model == trainer._onnx_model
-    assert trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph
-    assert onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(
-        trainer._onnx_model.graph
-    )
-
-
-@pytest.mark.parametrize("seed, device", [(0, "cpu"), (24, "cuda")])
-def testORTDeterministicCompute(seed, device):
-    # Common setup
-    optim_config = optim.LambConfig()
-    opts = orttrainer.ORTTrainerOptions(
-        {"debug": {"deterministic_compute": True}, "device": {"id": device, "mem_limit": 10 * 1024 * 1024}}
-    )
-
-    # Setup for the first ORTTRainer run
-    torch.manual_seed(seed)
-    set_seed(seed)
-    model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device)
-    first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
-    data, targets = batcher_fn(train_data, 0)
-    _ = first_trainer.train_step(data, targets)
-    assert first_trainer._onnx_model is not None
-
-    # Setup for the second ORTTRainer run
-    torch.manual_seed(seed)
-    set_seed(seed)
-    model, _, _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device)
-    second_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts)
-    _ = second_trainer.train_step(data, targets)
-    assert second_trainer._onnx_model is not None
-
-    # Compare two different instances with identical setup
-    assert id(first_trainer._onnx_model) != id(second_trainer._onnx_model)
-    _test_helpers.assert_onnx_weights(first_trainer, second_trainer)
-
-
-@pytest.mark.parametrize(
-    "seed,device,expected_loss,fetches",
-    [
-        (321, "cuda", [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], False),
-        (321, "cuda", [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], True),
-    ],
-)
-def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches):
-    return  # TODO: re-enable after nondeterminism on backend is fixed.
update numbers - - rtol = 1e-3 - total_steps = len(expected_loss) - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - loss_scaler = amp.DynamicLossScaler() - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - if fetches: - trainer._train_step_info.fetches = ["loss"] - loss = trainer.train_step(data, targets) - else: - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Eval once just to test fetches in action - val_data, val_targets = batcher_fn(val_data, 0) - if fetches: - trainer._train_step_info.fetches = ["loss"] - loss = trainer.eval_step(val_data, val_targets) - trainer._train_step_info.fetches = [] - loss, _ = trainer.eval_step(val_data, val_targets) - - # Compare loss to ground truth computed from current ORTTrainer API - _test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=rtol) - assert trainer._onnx_model is not None - - -def _recompute_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - expected_loss = { - 12: [10.5598, 10.4591, 10.3477, 10.2726, 10.1945], - 14: [10.54088, 10.498755, 10.386827, 10.338747, 10.262459], - } - return [ - (False, False, False, 0, expected_loss), # no recompute - (True, False, False, 0, expected_loss), # attn_dropout recompute - (False, True, False, 0, expected_loss), # gelu recompute - (False, False, True, 0, expected_loss), # transformer_layer recompute - (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer - ] - elif device_capability_major == 5: # M60 for CI machines - expected_loss = { - 12: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], - 14: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], - } - return [ - (False, False, False, 0, expected_loss), # no recompute - (True, False, False, 0, expected_loss), # attn_dropout recompute - (False, True, False, 0, expected_loss), # gelu recompute - (False, False, True, 0, expected_loss), # transformer_layer recompute - (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer - ] - - -@pytest.mark.parametrize("attn_dropout, gelu, transformer_layer, number_layers, expected_loss", _recompute_data()) -def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers, expected_loss): - seed = 321 - device = "cuda" - rtol = 1e-3 - total_steps = len(expected_loss[12]) - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "graph_transformer": { - "attn_dropout_recompute": attn_dropout, - "gelu_recompute": gelu, - "transformer_layer_recompute": transformer_layer, - "number_recompute_layers": number_layers, - }, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, 
optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Compare loss to ground truth computed from current ORTTrainer API - assert trainer._onnx_model is not None - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, True, rtol=rtol) - - -@pytest.mark.parametrize( - "seed,device,gradient_accumulation_steps,total_steps,expected_loss", - [ - ( - 0, - "cuda", - 1, - 12, - [ - 10.5368022919, - 10.4146203995, - 10.3635568619, - 10.2650547028, - 10.2284049988, - 10.1304626465, - 10.0853414536, - 9.9987659454, - 9.9472427368, - 9.8832416534, - 9.8223171234, - 9.8222122192, - ], - ), - ( - 42, - "cuda", - 3, - 12, - [ - 10.6455879211, - 10.6247081757, - 10.6361322403, - 10.5187482834, - 10.5345087051, - 10.5487670898, - 10.4833698273, - 10.4600019455, - 10.4535751343, - 10.3774127960, - 10.4144191742, - 10.3757553101, - ], - ), - ( - 123, - "cuda", - 7, - 12, - [ - 10.5353469849, - 10.5261383057, - 10.5240392685, - 10.5013713837, - 10.5678377151, - 10.5452117920, - 10.5184345245, - 10.4271221161, - 10.4458627701, - 10.4864749908, - 10.4416503906, - 10.4467563629, - ], - ), - ( - 321, - "cuda", - 12, - 12, - [ - 10.5773944855, - 10.5428829193, - 10.5974750519, - 10.5416746140, - 10.6009902954, - 10.5684127808, - 10.5759754181, - 10.5636739731, - 10.5613927841, - 10.5825119019, - 10.6031589508, - 10.6199369431, - ], - ), - ], -) -def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps, expected_loss): - return # TODO: re-enable after nondeterminism on backend is fixed. update numbers - rtol = 1e-3 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(expected_loss, actual_loss, rtol=rtol) - - -@pytest.mark.parametrize( - "dynamic_axes", - [ - (True), - (False), - ], -) -def testORTTrainerDynamicShape(dynamic_axes): - # Common setup - device = "cuda" - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model( - device, dynamic_axes=dynamic_axes - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - total_steps = 10 - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - if dynamic_axes: - # Forcing batches with different sizes to exercise dynamic shapes - data = data[: -(i + 1)] - targets = targets[: -(i + 1) * data.size(1)] - _, _ = trainer.train_step(data, targets) - - assert trainer._onnx_model is not None - - 
-@pytest.mark.parametrize( - "enable_onnx_contrib_ops", - [ - (True), - (False), - ], -) -def testORTTrainerInternalUseContribOps(enable_onnx_contrib_ops): - # Common setup - device = "cuda" - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({"_internal_use": {"enable_onnx_contrib_ops": enable_onnx_contrib_ops}}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - data, targets = batcher_fn(train_data, 0) - if not enable_onnx_contrib_ops and not pytorch_110: - with pytest.raises(Exception): # noqa: B017 - _, _ = trainer.train_step(data, targets) - else: - _, _ = trainer.train_step(data, targets) - - -@pytest.mark.parametrize( - "model_params", - [ - ( - [ - "decoder.weight", - "transformer_encoder.layers.0.linear1.bias", - "transformer_encoder.layers.0.linear2.weight", - "transformer_encoder.layers.1.self_attn.out_proj.weight", - "transformer_encoder.layers.1.self_attn.out_proj.bias", - ] - ), - ], -) -def testORTTrainerFrozenWeights(model_params): - # Common setup - device = "cuda" - total_steps = 10 - - # Setup ORTTrainer WITHOUT frozen weights - options = orttrainer.ORTTrainerOptions({}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = trainer.train_step(data, targets) - - # All model_params must be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert all([param in session_state for param in model_params]) - - # Setup ORTTrainer WITH frozen weights - options = orttrainer.ORTTrainerOptions({"utils": {"frozen_weights": model_params}}) - model, _, _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = trainer.train_step(data, targets) - - # All model_params CANNOT be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert not all([param in session_state for param in model_params]) - - -@pytest.mark.parametrize( - "loss_scaler, optimizer_config, gradient_accumulation_steps", - [ - (None, optim.AdamConfig(), 1), - (None, optim.LambConfig(), 1), - (None, optim.SGDConfig(), 1), - (amp.DynamicLossScaler(), optim.AdamConfig(), 1), - (amp.DynamicLossScaler(), optim.LambConfig(), 5), - # (amp.DynamicLossScaler(), optim.SGDConfig(), 1), # SGD doesnt support fp16 - ], -) -def testORTTrainerStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps): - # Common setup - seed = 1 - - class LinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - - def forward(self, y=None, x=None): - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - model_desc = { - "inputs": [ - ("x", [2, 2]), - ( - "label", - [ - 2, - ], - ), - ], - "outputs": [("loss", [], True), ("output", [2, 4])], - } - - # Dummy data - data1 = 
torch.randn(2, 2) - label1 = torch.tensor([0, 1], dtype=torch.int64) - data2 = torch.randn(2, 2) - label2 = torch.tensor([0, 1], dtype=torch.int64) - - # Setup training based on test parameters - opts = { - "debug": {"deterministic_compute": True}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - if loss_scaler: - opts["mixed_precision"] = {"enabled": True, "loss_scaler": loss_scaler} - opts = orttrainer.ORTTrainerOptions(opts) - - # Training session 1 - torch.manual_seed(seed) - set_seed(seed) - pt_model = LinearModel() - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - trainer = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) - - # Check state_dict keys before train. Must be empty - state_dict = trainer.state_dict() - assert state_dict == {} - - # Train once and check initial state - trainer.train_step(x=data1, label=label1) - state_dict = trainer.state_dict() - assert all([weight in state_dict["model"]["full_precision"] for weight in ["linear.bias", "linear.weight"]]) - - # Initialize training session 2 from state of Training 1 - torch.manual_seed(seed) - set_seed(seed) - trainer2 = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) - trainer2.load_state_dict(state_dict) - - # Verify state was loaded properly - _test_commons.assert_all_states_close_ort(state_dict, trainer2._load_state_dict.args[0]) - - # Perform a second step in both training session 1 and 2 and verify they match - trainer.train_step(x=data2, label=label2) - state_dict = trainer.state_dict() - trainer2.train_step(x=data2, label=label2) - state_dict2 = trainer2.state_dict() - _test_commons.assert_all_states_close_ort(state_dict, state_dict2) - - -def testORTTrainerNonPickableModel(): - # Common setup - import threading - - seed = 1 - - class UnpickableModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - self._lock = threading.Lock() - - def forward(self, y=None, x=None): - with self._lock: - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - model_desc = { - "inputs": [ - ("x", [2, 2]), - ( - "label", - [ - 2, - ], - ), - ], - "outputs": [("loss", [], True), ("output", [2, 4])], - } - - # Dummy data - data = torch.randn(2, 2) - label = torch.tensor([0, 1], dtype=torch.int64) - - # Setup training based on test parameters - opts = orttrainer.ORTTrainerOptions({"debug": {"deterministic_compute": True}}) - - # Training session - torch.manual_seed(seed) - set_seed(seed) - pt_model = UnpickableModel() - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - optim_config = optim.AdamConfig() - trainer = orttrainer.ORTTrainer(pt_model, model_desc, optim_config, loss_fn=loss_fn, options=opts) - - # Train must succeed despite warning - _, _ = trainer.train_step(data, label) - - -############################################################################### -# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ -############################################################################### - - -@pytest.mark.parametrize("seed,device", [(1234, "cuda")]) -def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device): - # Common data - rtol = 1e-7 - total_steps = 5 - - # Setup for the experimental ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - 
{ - "device": {"id": device}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - # Training loop - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _ = trainer.train_step(data, targets) - - # Setup for the legacy ORTTrainer run - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, my_loss, model_desc, "LambOptimizer", None, lr_desc, device, _use_deterministic_compute=True - ) - # Training loop - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - - # Compare legacy vs experimental APIs - _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=rtol) - - -@pytest.mark.parametrize( - "seed,device", - [ - (321, "cuda"), - ], -) -def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device): - # Common data - total_steps = 128 - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - loss_scaler = amp.DynamicLossScaler() - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - experimental_preds_dtype = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, exp_preds = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - experimental_preds_dtype.append(exp_preds.dtype) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - loss_scaler = Legacy_LossScaler("ort_test_input_loss_scalar", True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - use_mixed_precision=True, - loss_scaler=loss_scaler, - ) - # Training loop - legacy_loss = [] - legacy_preds_dtype = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(leg_loss.cpu()) - legacy_preds_dtype.append(leg_preds.dtype) - - # Compare legacy vs experimental APIs - assert experimental_preds_dtype == legacy_preds_dtype - _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer) - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -@pytest.mark.parametrize( - "seed,device,gradient_accumulation_steps,total_steps", - [ - (0, "cuda", 1, 12), - (42, "cuda", 3, 12), - (123, "cuda", 7, 12), - (321, "cuda", 12, 12), - ], -) -def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps): - # Common data - torch.set_printoptions(precision=10) - - # Setup 
experimental API - torch.manual_seed(seed) - set_seed(seed) - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, _ = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - gradient_accumulation_steps=gradient_accumulation_steps, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(leg_loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -@pytest.mark.parametrize( - "seed,device,optimizer_config,lr_scheduler, get_lr_this_step", - [ - ( - 0, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 0, - "cuda", - optim.LambConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 0, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 42, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 42, - "cuda", - optim.LambConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 42, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 123, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 123, - "cuda", - optim.LambConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 123, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 321, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ( - 321, - "cuda", - optim.LambConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ( - 321, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ], -) -def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_config, lr_scheduler, get_lr_this_step): - # Common data - total_steps = 10 - lr = 0.001 - warmup = 0.5 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - torch.set_printoptions(precision=10) - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - 
if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - options = orttrainer.ORTTrainerOptions( - {"device": {"id": device}, "debug": {"deterministic_compute": True}, "lr_scheduler": lr_scheduler} - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optimizer_config(lr=lr) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, exp_preds = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - - if optimizer_config == optim.AdamConfig: - legacy_optimizer_config = "AdamOptimizer" - elif optimizer_config == optim.LambConfig: - legacy_optimizer_config = "LambOptimizer" - elif optimizer_config == optim.SGDConfig: - legacy_optimizer_config = "SGDOptimizer" - else: - raise RuntimeError("Invalid optimizer_config") - - if ( - get_lr_this_step == _test_commons.legacy_constant_lr_scheduler - or get_lr_this_step == _test_commons.legacy_linear_lr_scheduler - ): - get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup) - elif get_lr_this_step == _test_commons.legacy_cosine_lr_scheduler: - get_lr_this_step = partial( - get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, cycles=cycles - ) - elif get_lr_this_step == _test_commons.legacy_poly_lr_scheduler: - get_lr_this_step = partial( - get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end - ) - else: - raise RuntimeError("Invalid get_lr_this_step") - - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - legacy_optimizer_config, - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - get_lr_this_step=get_lr_this_step, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, leg_preds = legacy_trainer.train_step(data, targets) - legacy_loss.append(leg_loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -def testLossScalerLegacyAndExperimentalFullCycle(): - orttrainer.TrainStepInfo( - optimizer_config=optim.LambConfig(lr=0.001), all_finite=True, fetches=[], optimization_step=0, step=0 - ) - new_ls = amp.DynamicLossScaler() - old_ls = Legacy_LossScaler("ort_test_input_loss_scaler", True) - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(new_ls.loss_scale, old_ls.loss_scale_) - assert new_ls.up_scale_window == old_ls.up_scale_window_ - assert_allclose(new_ls.min_loss_scale, old_ls.min_loss_scale_) - assert_allclose(new_ls.max_loss_scale, 
old_ls.max_loss_scale_) - - # Performing 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True) - for _cycles in range(1, 10): - # 1999 updates without overflow produces 1999 stable steps - for _i in range(1, 2000): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # 2000th update without overflow doubles the loss and zero stable steps until max_loss_scale is reached - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # After 8 cycles, loss scale should be float(1 << 16)*(2**8) - assert_allclose(new_loss_scale, old_loss_scale) - - # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on - for _count in range(1, 2050): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # Setting train_step_info.all_finite = False to test down scaling - train_step_info.all_finite = False - - # Performing 24 updates to half the loss scale each time - for _count in range(1, 25): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # After 24 updates with gradient overflow, loss scale is 1.0 - assert_allclose(new_loss_scale, old_loss_scale) - - # After 25 updates, min_loss_scale is reached and loss scale is not halfed from that point on - for _count in range(1, 5): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - -def testLossScalerLegacyAndExperimentalRandomAllFinite(): - new_ls = amp.DynamicLossScaler() - old_ls = Legacy_LossScaler("ort_test_input_loss_scaler", True) - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(new_ls.loss_scale, old_ls.loss_scale_) - assert new_ls.up_scale_window == old_ls.up_scale_window_ - assert_allclose(new_ls.min_loss_scale, old_ls.min_loss_scale_) - assert_allclose(new_ls.max_loss_scale, old_ls.max_loss_scale_) - - import random - - out = [] - for _ in range(1, 64): - train_step_info.all_finite = bool(random.getrandbits(1)) - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - out.append(new_loss_scale) - assert new_loss_scale > 1e-7 - - -def testORTTrainerRunSymbolicShapeInfer(): - # Common data - seed = 0 - total_steps = 12 - device = "cuda" - torch.set_printoptions(precision=10) - - # Setup without symbolic shape inference - torch.manual_seed(seed) - set_seed(seed) - options = orttrainer.ORTTrainerOptions({"device": {"id": device}, "debug": 
{"deterministic_compute": True}}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - expected_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - expected_loss.append(loss.cpu()) - - # Setup with symbolic shape inference - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - options.utils.run_symbolic_shape_infer = True - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - new_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - new_loss.append(loss.cpu()) - - # Setup with symbolic shape inference in legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - run_symbolic_shape_infer=True, - _use_deterministic_compute=True, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(loss.cpu()) - - # Compare losses - _test_helpers.assert_model_outputs(new_loss, expected_loss) - _test_helpers.assert_model_outputs(legacy_loss, expected_loss) - - -@pytest.mark.parametrize( - "test_input", - [ - ( - { - "distributed": {"enable_adasum": True}, - } - ) - ], -) -def testORTTrainerOptionsEnabledAdasumFlag(test_input): - """Test the enabled_adasum flag values when set enabled""" - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values.distributed.enable_adasum is True - - -@pytest.mark.parametrize( - "test_input", - [ - ( - { - "distributed": {"enable_adasum": False}, - } - ) - ], -) -def testORTTrainerOptionsDisabledAdasumFlag(test_input): - """Test the enabled_adasum flag values when set disabled""" - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values.distributed.enable_adasum is False - - -def testORTTrainerUnusedInput(): - class UnusedInputModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.mean(x) - - model = UnusedInputModel() - model_desc = {"inputs": [("x", [1]), ("y", [1])], "outputs": [("loss", [], True)]} - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config) - - # Run just one step to make sure there are no iobinding errors for the unused input. 
- try: - trainer.train_step(torch.FloatTensor([1.0]), torch.FloatTensor([1.0])) - except RuntimeError: - pytest.fail("RuntimeError doing train_step with an unused input.") - - -@pytest.mark.parametrize( - "debug_files", - [ - { - "model_after_graph_transforms_path": "transformed.onnx", - "model_with_gradient_graph_path": "transformed_grad.onnx", - "model_with_training_graph_path": "training.onnx", - "model_with_training_graph_after_optimization_path": "training_optimized.onnx", - }, - {"model_after_graph_transforms_path": "transformed.onnx", "model_with_training_graph_path": ""}, - ], -) -def testTrainingGraphExport(debug_files): - device = "cuda" - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - - with tempfile.TemporaryDirectory() as tempdir: - debug_paths = {} - for k, v in debug_files.items(): - debug_paths[k] = os.path.join(tempdir, v) - opts = orttrainer.ORTTrainerOptions({"device": {"id": device}, "debug": {"graph_save_paths": debug_paths}}) - optim_config = optim.AdamConfig() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - trainer.train_step(data, targets) - for k, v in debug_files.items(): - path = debug_paths[k] - if len(v) > 0: - assert os.path.isfile(path) - saved_graph = onnx.load(path).graph - if k == "model_with_training_graph_path": - assert any("AdamOptimizer" in n.op_type for n in saved_graph.node) - elif k == "model_with_gradient_graph_path": - assert any("Grad" in n.name for n in saved_graph.node) - elif k == "model_after_graph_transforms_path": - assert any("LayerNormalization" in n.op_type for n in saved_graph.node) - elif k == "model_with_training_graph_after_optimization_path": - assert any("FusedMatMul" in n.op_type for n in saved_graph.node) - # remove saved file - os.remove(path) - else: - assert not os.path.isfile(path) - - -def _adam_max_norm_clip_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.592951, - 10.067989, - 9.619152, - 9.245731, - 8.881137, - 8.578644, - 8.280573, - 8.063023, - 7.797933, - 7.486215, - 7.233806, - 7.011791, - ], - 14: [ - 10.584141, - 10.068119, - 9.581743, - 9.191472, - 8.880169, - 8.5352, - 8.311425, - 8.061202, - 7.773032, - 7.523009, - 7.258711, - 7.02805, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.592951, - 10.068722, - 9.620503, - 9.247791, - 8.883972, - 8.582286, - 8.285027, - 8.068308, - 7.803638, - 7.492318, - 7.240352, - 7.018665, - ], - 14: [ - 10.584141, - 10.068845, - 9.583107, - 9.193537, - 8.882966, - 8.538839, - 8.315872, - 8.066408, - 7.778978, - 7.529708, - 7.265849, - 7.035439, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.647908, - 10.144501, - 9.672352, - 9.306980, - 8.956026, - 8.602655, - 8.351079, - 8.088144, - 7.867220, - 7.564082, - 7.289846, - 7.073726, - ], - 14: [ - 10.697515, - 10.229034, - 9.765422, - 9.428294, - 9.080612, - 8.715208, - 8.459574, - 8.169073, - 7.940211, - 7.654147, - 7.390446, - 7.166227, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.647908, - 10.145191, - 9.673690, - 9.309031, - 8.959020, - 8.606632, - 8.355836, - 8.093478, - 7.873327, - 7.570731, - 7.296772, - 7.0809422, - ], - 14: [ - 10.697515, - 10.22967, - 9.766556, - 9.430037, - 9.083106, - 8.718601, - 8.463726, - 8.17396, - 7.945755, - 7.660188, - 
7.396963, - 7.172944, - ], - }, - ), - ] - elif device_capability_major == 5: # M60 for CI machines (Python Packaging Pipeline) - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.618382, - 10.08292, - 9.603334, - 9.258133, - 8.917768, - 8.591574, - 8.318401, - 8.042292, - 7.783608, - 7.50226, - 7.236041, - 7.035602, - ], - 14: [ - 10.618382, - 10.08292, - 9.603334, - 9.258133, - 8.917768, - 8.591574, - 8.318401, - 8.042292, - 7.783608, - 7.50226, - 7.236041, - 7.035602, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.618382, - 10.083632, - 9.604639, - 9.260109, - 8.920504, - 8.595082, - 8.322799, - 8.047493, - 7.78929, - 7.508382, - 7.242587, - 7.042367, - ], - 14: [ - 10.618382, - 10.083632, - 9.604639, - 9.260109, - 8.920504, - 8.595082, - 8.322799, - 8.047493, - 7.78929, - 7.508382, - 7.242587, - 7.042367, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.68639, - 10.102986, - 9.647681, - 9.293091, - 8.958928, - 8.625297, - 8.351107, - 8.079577, - 7.840723, - 7.543044, - 7.284141, - 7.072688, - ], - 14: [ - 10.68639, - 10.102986, - 9.647681, - 9.293091, - 8.958928, - 8.625297, - 8.351107, - 8.079577, - 7.840723, - 7.543044, - 7.284141, - 7.072688, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.68639, - 10.103672, - 9.649025, - 9.295167, - 8.961777, - 8.629059, - 8.355571, - 8.084871, - 7.846589, - 7.549438, - 7.290722, - 7.079446, - ], - 14: [ - 10.697515, - 10.22967, - 9.766556, - 9.430037, - 9.083106, - 8.718601, - 8.463726, - 8.17396, - 7.945755, - 7.660188, - 7.396963, - 7.172944, - ], - }, - ), - ] - - -@pytest.mark.parametrize( - "seed,device,max_norm_clip,gradient_accumulation_steps,total_steps,expected_loss", _adam_max_norm_clip_data() -) -def testORTTrainerAdamMaxNormClip(seed, device, max_norm_clip, gradient_accumulation_steps, total_steps, expected_loss): - rtol = 1e-5 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.AdamConfig(lr=0.001, max_norm_clip=max_norm_clip) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu().item()) - - # Compare legacy vs experimental APIs - assert trainer._onnx_model is not None - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, rtol=rtol) - - -def _lamb_max_norm_clip_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.592951, - 10.487728, - 10.422251, - 10.350913, - 10.244248, - 10.213003, - 10.129222, - 10.095112, - 10.035983, - 9.974586, - 9.909771, - 9.874278, - ], - 14: [ - 10.584141, - 10.497192, - 10.389251, - 10.286045, - 10.231354, - 10.17018, - 10.066779, - 10.048138, - 9.958029, - 9.8908, - 9.82965, - 9.755484, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.592951, - 10.452503, - 10.349832, - 10.245314, - 10.106587, - 10.046009, - 9.934781, - 
9.875164, - 9.792067, - 9.704592, - 9.617104, - 9.563070, - ], - 14: [ - 10.584141, - 10.461154, - 10.315399, - 10.178979, - 10.092329, - 9.999928, - 9.869949, - 9.824564, - 9.707565, - 9.61643, - 9.532847, - 9.439593, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.647908, - 10.566276, - 10.476154, - 10.406275, - 10.311079, - 10.240053, - 10.196469, - 10.113955, - 10.117376, - 10.013077, - 9.930301, - 9.893368, - ], - 14: [ - 10.697515, - 10.631279, - 10.528757, - 10.496689, - 10.411219, - 10.322109, - 10.297314, - 10.215549, - 10.149698, - 10.087336, - 10.010884, - 9.934544, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.647908, - 10.531957, - 10.405246, - 10.302971, - 10.176583, - 10.075583, - 10.005772, - 9.897825, - 9.875748, - 9.748932, - 9.642885, - 9.586762, - ], - 14: [ - 10.697515, - 10.596729, - 10.457815, - 10.393475, - 10.277581, - 10.158909, - 10.108126, - 10.000326, - 9.912526, - 9.826057, - 9.727899, - 9.633768, - ], - }, - ), - ] - elif device_capability_major == 5: # M60 for CI machines (Python Packaging Pipeline) - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.618382, - 10.50222, - 10.403347, - 10.35298, - 10.288447, - 10.237399, - 10.184225, - 10.089048, - 10.008952, - 9.972644, - 9.897674, - 9.84524, - ], - 14: [0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.618382, - 10.466732, - 10.330871, - 10.24715, - 10.150972, - 10.069127, - 9.98974, - 9.870169, - 9.763693, - 9.704323, - 9.605957, - 9.533117, - ], - 14: [1, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.68639, - 10.511692, - 10.447308, - 10.405255, - 10.334866, - 10.261473, - 10.169422, - 10.107138, - 10.069889, - 9.97798, - 9.928105, - 9.896435, - ], - 14: [2, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.68639, - 10.477489, - 10.376671, - 10.301725, - 10.200718, - 10.098477, - 9.97995, - 9.890104, - 9.828899, - 9.713555, - 9.639567, - 9.589856, - ], - 14: [3, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ] - - -@pytest.mark.parametrize( - "seed,device,max_norm_clip, gradient_accumulation_steps,total_steps,expected_loss", _lamb_max_norm_clip_data() -) -def testORTTrainerLambMaxNormClip(seed, device, max_norm_clip, gradient_accumulation_steps, total_steps, expected_loss): - rtol = 1e-3 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001, max_norm_clip=max_norm_clip) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu().item()) - - # Compare legacy vs experimental APIs - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, rtol=rtol) diff --git a/orttraining/orttraining/test/python/orttraining_test_transformers.py b/orttraining/orttraining/test/python/orttraining_test_transformers.py deleted file mode 100644 index 
dbaf4a293c466..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_transformers.py +++ /dev/null @@ -1,480 +0,0 @@ -import random -import unittest - -import numpy as np -import torch -from numpy.testing import assert_allclose -from orttraining_test_data_loader import BatchArgsOption, ids_tensor -from orttraining_test_utils import get_lr, run_test -from transformers import BertConfig, BertForPreTraining - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - - -class BertModelTest(unittest.TestCase): - class BertModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - scope=None, - device="cpu", - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.scope = scope - self.device = device - - # 1. 
superset of bert input/output descs - # see BertPreTrainedModel doc - self.input_ids_desc = IODescription( - "input_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size - ) - self.attention_mask_desc = IODescription( - "attention_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2 - ) - self.token_type_ids_desc = IODescription( - "token_type_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2 - ) - self.position_ids_desc = IODescription( - "position_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.max_position_embeddings - ) - self.head_mask_desc = IODescription( - "head_mask", [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2 - ) - self.inputs_embeds_desc = IODescription( - "inputs_embeds", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - - self.encoder_hidden_states_desc = IODescription( - "encoder_hidden_states", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - self.encoder_attention_mask_desc = IODescription( - "encoder_attention_mask", ["batch", "max_seq_len_in_batch"], torch.float32 - ) - - # see BertForPreTraining doc - self.masked_lm_labels_desc = IODescription( - "masked_lm_labels", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size - ) - self.next_sentence_label_desc = IODescription( - "next_sentence_label", - [ - "batch", - ], - torch.int64, - num_classes=2, - ) - - # outputs - self.loss_desc = IODescription( - "loss", - [ - 1, - ], - torch.float32, - ) - self.prediction_scores_desc = IODescription( - "prediction_scores", ["batch", "max_seq_len_in_batch", self.vocab_size], torch.float32 - ) - - self.seq_relationship_scores_desc = IODescription( - "seq_relationship_scores", ["batch", 2], torch.float32 - ) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32) - self.hidden_states_desc = IODescription( - "hidden_states", - [self.num_hidden_layers, "batch", "max_seq_len_in_batch", self.hidden_size], - torch.float32, - ) - self.attentions_desc = IODescription( - "attentions", - [ - self.num_hidden_layers, - "batch", - self.num_attention_heads, - "max_seq_len_in_batch", - "max_seq_len_in_batch", - ], - torch.float32, - ) - self.last_hidden_state_desc = IODescription( - "last_hidden_state", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - self.pooler_output_desc = IODescription("pooler_output", ["batch", self.hidden_size], torch.float32) - - def BertForPreTraining_descs(self): - return ModelDescription( - [ - self.input_ids_desc, - self.attention_mask_desc, - self.token_type_ids_desc, - self.masked_lm_labels_desc, - self.next_sentence_label_desc, - ], - # returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided - # hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states - [ - self.loss_desc, - self.prediction_scores_desc, - self.seq_relationship_scores_desc, - # hidden_states_desc, attentions_desc - ], - ) - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device) - - input_mask = None - if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device) - - 
sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device) - choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device) - - config = BertConfig( - vocab_size=self.vocab_size, - vocab_size_or_config_json_file=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - ) - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def create_and_check_bert_for_pretraining( - self, - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - option_use_internal_get_lr_this_step=[True], # noqa: B006 - option_use_internal_loss_scaler=[True], # noqa: B006 - ): - seed = 42 - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - model = BertForPreTraining(config=config) - model.eval() - loss, prediction_scores, seq_relationship_score = model( - input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - masked_lm_labels=token_labels, - next_sentence_label=sequence_labels, - ) - model_desc = ModelDescription( - [ - self.input_ids_desc, - self.attention_mask_desc, - self.token_type_ids_desc, - self.masked_lm_labels_desc, - self.next_sentence_label_desc, - ], - [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc], - ) - - from collections import namedtuple - - MyArgs = namedtuple( - "MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len" - ) - - dataset_len = 100 - epochs = 8 - max_steps = epochs * dataset_len - args = MyArgs( - local_rank=0, - world_size=1, - max_steps=max_steps, - learning_rate=0.00001, - warmup_proportion=0.01, - batch_size=13, - seq_len=7, - ) - - def get_lr_this_step(global_step): - return get_lr(args, global_step) - - loss_scaler = LossScaler("loss_scale_input_name", True, up_scale_window=2000) - - for fp16 in option_fp16: - for allreduce_post_accumulation in option_allreduce_post_accumulation: - for gradient_accumulation_steps in option_gradient_accumulation_steps: - for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step: - for use_internal_loss_scaler in option_use_internal_loss_scaler: - for split_batch in option_split_batch: - print("gradient_accumulation_steps:", gradient_accumulation_steps) - print("split_batch:", split_batch) - - seed = 42 - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - ( - old_api_loss_ort, - old_api_prediction_scores_ort, - old_api_seq_relationship_score_ort, - ) = run_test( - model, - model_desc, - self.device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - 
use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - split_batch, - dataset_len, - epochs, - use_new_api=False, - ) - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - if use_internal_get_lr_this_step and use_internal_loss_scaler: - ( - new_api_loss_ort, - new_api_prediction_scores_ort, - new_api_seq_relationship_score_ort, - ) = run_test( - model, - model_desc, - self.device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - split_batch, - dataset_len, - epochs, - use_new_api=True, - ) - - assert_allclose(old_api_loss_ort, new_api_loss_ort) - assert_allclose(old_api_prediction_scores_ort, new_api_prediction_scores_ort) - assert_allclose( - old_api_seq_relationship_score_ort, new_api_seq_relationship_score_ort - ) - - def setUp(self): - self.model_tester = BertModelTest.BertModelTester(self) - - def test_for_pretraining_mixed_precision(self): - # It would be better to test both with/without mixed precision and allreduce_post_accumulation. - # However, stress test of all the 4 cases is not stable at least on the test machine. - # There we only test mixed precision and allreduce_post_accumulation because it is the most useful use cases. - option_fp16 = [True] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_mixed_precision_with_gradient_accumulation(self): - # It would be better to test both with/without mixed precision and allreduce_post_accumulation. - # However, stress test of all the 4 cases is not stable at least on the test machine. - # There we only test mixed precision and allreduce_post_accumulation because it is the most useful use cases. - option_fp16 = [True] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_all(self): - # This test is not stable because it create and run ORTSession multiple times. - # It occasionally gets seg fault at ~MemoryPattern() - # when releasing patterns_. In order not to block PR merging CI test, - # this test is broke into following individual tests. 
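The removed BERT pretraining tests reseed every random source before each run so that the old and new trainer front-ends can be compared step by step with assert_allclose. A small helper along those lines, based on the calls visible in the deleted code, might look like this (sketch only):

import random

import numpy as np
import torch

import onnxruntime


def seed_everything(seed: int) -> None:
    # Seed all random sources used by the tests so two runs are comparable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    onnxruntime.set_seed(seed)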
- option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1, 8] - option_split_batch = [BatchArgsOption.List, BatchArgsOption.Dict, BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_list_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.List] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.Dict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_list_and_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_list_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.List] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.Dict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_list_and_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - -if __name__ == "__main__": - unittest.main() diff --git 
a/orttraining/orttraining/test/python/orttraining_test_utils.py b/orttraining/orttraining/test/python/orttraining_test_utils.py deleted file mode 100644 index 527cfb8a0ba7d..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_utils.py +++ /dev/null @@ -1,246 +0,0 @@ -import math - -import torch -from orttraining_test_data_loader import BatchArgsOption, create_ort_test_dataloader, split_batch - -from onnxruntime.capi.ort_trainer import IODescription, ORTTrainer -from onnxruntime.training import amp, optim, orttrainer -from onnxruntime.training.optim import _LRScheduler - - -def warmup_cosine(x, warmup=0.002): - if x < warmup: - return x / warmup - return 0.5 * (1.0 + torch.cos(math.pi * x)) - - -def warmup_constant(x, warmup=0.002): - if x < warmup: - return x / warmup - return 1.0 - - -def warmup_linear(x, warmup=0.002): - if x < warmup: - return x / warmup - return max((x - 1.0) / (warmup - 1.0), 0.0) - - -def warmup_poly(x, warmup=0.002, degree=0.5): - if x < warmup: - return x / warmup - return (1.0 - x) ** degree - - -SCHEDULES = { - "warmup_cosine": warmup_cosine, - "warmup_constant": warmup_constant, - "warmup_linear": warmup_linear, - "warmup_poly": warmup_poly, -} - - -def get_lr(args, training_steps, schedule="warmup_poly"): - if args.max_steps == -1: - return args.learning_rate - - schedule_fct = SCHEDULES[schedule] - return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion) - - -def map_optimizer_attributes(name): - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) - if no_decay: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - else: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - - -class WrapLRScheduler(_LRScheduler): - def __init__(self, get_lr_this_step): - super().__init__() - self.get_lr_this_step = get_lr_this_step - - def get_lr(self, train_step_info): - return [self.get_lr_this_step(train_step_info.optimization_step)] - - -def run_test( - model, - model_desc, - device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - batch_args_option, - dataset_len, - epochs, - use_new_api, -): - dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device) - - if use_new_api: - assert use_internal_loss_scaler, "new api should always use internal loss scaler" - - new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step) - - new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "device": {"id": device}, - "mixed_precision": {"enabled": fp16, "loss_scaler": new_api_loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": True}, - "distributed": {"allreduce_post_accumulation": True}, - "lr_scheduler": new_api_lr_scheduler, - } - ) - - param_optimizer = list(model.named_parameters()) - params = [ - { - "params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - { - "params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - ] - - vocab_size = 99 - new_model_desc 
= { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - ), - ( - "next_sentence_label", - [ - "batch", - ], - ), - ], - "outputs": [ - ( - "loss", - [ - 1, - ], - True, - ), - ("prediction_scores", ["batch", "max_seq_len_in_batch", vocab_size]), - ("seq_relationship_scores", ["batch", 2]), - ], - } - - optim_config = optim.LambConfig(params=params, lr=2e-5) - model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options) - print("running with new frontend API") - else: - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes=map_optimizer_attributes, - learning_rate_description=IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device=device, - _enable_internal_postprocess=True, - gradient_accumulation_steps=gradient_accumulation_steps, - # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6 - world_rank=args.local_rank, - world_size=args.world_size, - use_mixed_precision=fp16, - allreduce_post_accumulation=allreduce_post_accumulation, - get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None, - loss_scaler=loss_scaler if use_internal_loss_scaler else None, - _opset_version=14, - _use_deterministic_compute=True, - ) - print("running with old frontend API") - - # training loop - eval_batch = None - if not use_new_api: - model.train() - for _epoch in range(epochs): - for step, batch in enumerate(dataloader): - if eval_batch is None: - eval_batch = batch - - if not use_internal_get_lr_this_step: - lr = get_lr_this_step(step) - learning_rate = torch.tensor([lr]) - - if not use_internal_loss_scaler and fp16: - loss_scale = torch.tensor([loss_scaler.loss_scale_]) - - if batch_args_option == BatchArgsOption.List: - if not use_internal_get_lr_this_step: - batch = [*batch, learning_rate] # noqa: PLW2901 - if not use_internal_loss_scaler and fp16: - batch = [*batch, loss_scale] # noqa: PLW2901 - outputs = model.train_step(*batch) - elif batch_args_option == BatchArgsOption.Dict: - args, kwargs = split_batch(batch, model_desc.inputs_, 0) - if not use_internal_get_lr_this_step: - kwargs["Learning_Rate"] = learning_rate - if not use_internal_loss_scaler and fp16: - kwargs[model.loss_scale_input_name] = loss_scale - outputs = model.train_step(*args, **kwargs) - else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs - args, kwargs = split_batch(batch, model_desc.inputs_, args_count) - if not use_internal_get_lr_this_step: - kwargs["Learning_Rate"] = learning_rate - if not use_internal_loss_scaler and fp16: - kwargs[model.loss_scale_input_name] = loss_scale - outputs = model.train_step(*args, **kwargs) - - # eval - if batch_args_option == BatchArgsOption.List: - outputs = model.eval_step(*batch) - elif batch_args_option == BatchArgsOption.Dict: - args, kwargs = split_batch(batch, model_desc.inputs_, 0) - outputs = model.eval_step(*args, **kwargs) - else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs - args, kwargs = split_batch(batch, model_desc.inputs_, args_count) - outputs = model.eval_step(*args, **kwargs) - - return (output.cpu().numpy() for output in outputs) diff --git a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py 
b/orttraining/orttraining/test/python/orttraining_transformer_trainer.py deleted file mode 100644 index bce726871bacf..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py +++ /dev/null @@ -1,357 +0,0 @@ -# adapted from Trainer.py of huggingface transformers - -import json -import logging -import os -import random -from typing import Callable, Dict, List, NamedTuple, Optional - -import numpy as np -import torch -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Dataset -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import SequentialSampler -from tqdm import tqdm, trange -from transformers.data.data_collator import DefaultDataCollator -from transformers.modeling_utils import PreTrainedModel -from transformers.training_args import TrainingArguments - -import onnxruntime -from onnxruntime.training import amp, optim, orttrainer - -try: - from torch.utils.tensorboard import SummaryWriter - - _has_tensorboard = True -except ImportError: - try: - from tensorboardX import SummaryWriter # noqa: F401 - - _has_tensorboard = True - except ImportError: - _has_tensorboard = False - - -def is_tensorboard_available(): - return _has_tensorboard - - -logger = logging.getLogger(__name__) - - -def set_seed(seed: int): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - -class EvalPrediction(NamedTuple): - predictions: np.ndarray - label_ids: np.ndarray - - -class PredictionOutput(NamedTuple): - predictions: np.ndarray - label_ids: Optional[np.ndarray] - metrics: Optional[Dict[str, float]] - - -class TrainOutput(NamedTuple): - global_step: int - training_loss: float - - -def get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps, base_lr): - def lr_lambda_linear(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) - - def lambda_lr_get_lr(current_global_step): - # LambdaLR increment self.last_epoch at evert sept() - return base_lr * lr_lambda_linear(current_global_step) - - return lambda_lr_get_lr - - -class ORTTransformerTrainer: - """ """ - - model: PreTrainedModel - args: TrainingArguments - train_dataset: Dataset - eval_dataset: Dataset - compute_metrics: Callable[[EvalPrediction], Dict] - - def __init__( - self, - model: PreTrainedModel, - model_desc: dict, - args: TrainingArguments, - train_dataset: Dataset, - eval_dataset: Dataset, - compute_metrics: Callable[[EvalPrediction], Dict], - world_size: Optional[int] = 1, - ): - """ """ - - self.model = model - self.model_desc = model_desc - self.args = args - self.world_size = world_size - self.data_collator = DefaultDataCollator() - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self.compute_metrics = compute_metrics - set_seed(self.args.seed) - # Create output directory if needed - if self.args.local_rank in [-1, 0]: - os.makedirs(self.args.output_dir, exist_ok=True) - - def get_train_dataloader(self) -> DataLoader: - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - train_sampler = ( - SequentialSampler(self.train_dataset) - if self.args.local_rank == -1 - else DistributedSampler(self.train_dataset) - ) - return DataLoader( - self.train_dataset, - batch_size=self.args.train_batch_size, - 
sampler=train_sampler, - collate_fn=self.data_collator.collate_batch, - ) - - def get_eval_dataloader(self) -> DataLoader: - return DataLoader( - self.eval_dataset, - batch_size=self.args.eval_batch_size, - shuffle=False, - collate_fn=self.data_collator.collate_batch, - ) - - def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: - # We use the same batch_size as for eval. - return DataLoader( - test_dataset, - batch_size=self.args.eval_batch_size, - shuffle=False, - collate_fn=self.data_collator.collate_batch, - ) - - def train(self): - """ - Main training entry point. - """ - train_dataloader = self.get_train_dataloader() - - if self.args.max_steps > 0: - t_total = self.args.max_steps - num_train_epochs = ( - self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 - ) - else: - t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) - num_train_epochs = self.args.num_train_epochs - - lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps / float(t_total)) - - loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None - device = self.args.device.type - - device = f"{device}:{self.args.device.index}" if self.args.device.index else f"{device}:0" - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": self.args.gradient_accumulation_steps}, - "device": {"id": device}, - "mixed_precision": {"enabled": self.args.fp16, "loss_scaler": loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": False}, - "distributed": { - # we are running single node multi gpu test. thus world_rank = local_rank - # and world_size = self.args.n_gpu - "world_rank": max(0, self.args.local_rank), - "world_size": int(self.world_size), - "local_rank": max(0, self.args.local_rank), - "allreduce_post_accumulation": True, - }, - "lr_scheduler": lr_scheduler, - } - ) - - param_optimizer = list(self.model.named_parameters()) - params = [ - { - "params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], - "weight_decay_mode": 1, - }, - { - "params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], - "weight_decay_mode": 1, - }, - ] - - optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) - self.model = orttrainer.ORTTrainer(self.model, self.model_desc, optim_config, options=options) - - # Train! - logger.info("***** Running training *****") - logger.info(" Num examples = %d", len(train_dataloader.dataset)) - logger.info(" Num Epochs = %d", num_train_epochs) - logger.info(" Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size) - logger.info( - " Total train batch size (w. 
parallel, distributed & accumulation) = %d", - self.args.train_batch_size - * self.args.gradient_accumulation_steps - * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), - ) - logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) - logger.info(" Total optimization steps = %d", t_total) - - global_step = 0 - epochs_trained = 0 - steps_trained_in_current_epoch = 0 - - tr_loss = 0.0 - logging_loss = 0.0 - train_iterator = trange( - epochs_trained, - int(num_train_epochs), - desc="Epoch", - disable=self.args.local_rank not in [-1, 0], - ) - - for _epoch in train_iterator: - epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) - for step, inputs in enumerate(epoch_iterator): - # Skip past any already trained steps if resuming training - if steps_trained_in_current_epoch > 0: - steps_trained_in_current_epoch -= 1 - continue - - tr_loss += self._training_step(self.model, inputs) - - if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( - len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator) - ): - global_step += 1 - - if self.args.local_rank in [-1, 0]: - if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or ( - global_step == 1 and self.args.logging_first_step - ): - logs = {} - if self.args.evaluate_during_training: - results = self.evaluate() - for key, value in results.items(): - eval_key = f"eval_{key}" - logs[eval_key] = value - - loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps - - logs["loss"] = loss_scalar - logging_loss = tr_loss - - epoch_iterator.write(json.dumps({**logs, **{"step": global_step}})) - - if self.args.max_steps > 0 and global_step > self.args.max_steps: - epoch_iterator.close() - break - if self.args.max_steps > 0 and global_step > self.args.max_steps: - train_iterator.close() - break - - logger.info("\n\nTraining completed. \n\n") - return TrainOutput(global_step, tr_loss / global_step) - - def _training_step(self, model, inputs: Dict[str, torch.Tensor]) -> float: - for k, v in inputs.items(): - inputs[k] = v.to(self.args.device) - - outputs = model.train_step(**inputs) - loss = outputs[0] # model outputs are always tuple in transformers (see doc) - - return loss.item() - - def save_model(self, output_dir: Optional[str] = None): - output_dir = output_dir if output_dir is not None else self.args.output_dir - os.makedirs(output_dir, exist_ok=True) - self.model.save_as_onnx(os.path.join(output_dir, "transformer.onnx")) - - def evaluate(self) -> Dict[str, float]: - """ - Run evaluation and return metrics. - - Returns: - A dict containing: - - the eval loss - - the potential metrics computed from the predictions - """ - eval_dataloader = self.get_eval_dataloader() - - output = self._prediction_loop(eval_dataloader, description="Evaluation") - return output.metrics - - def predict(self, test_dataset: Dataset) -> PredictionOutput: - """ - Run prediction and return predictions and potential metrics. - - Depending on the dataset and your use case, your test dataset may contain labels. - In that case, this method will also return metrics, like in evaluate(). - """ - test_dataloader = self.get_test_dataloader(test_dataset) - return self._prediction_loop(test_dataloader, description="Prediction") - - def _prediction_loop(self, dataloader: DataLoader, description: str) -> PredictionOutput: - """ - Prediction/evaluation loop, shared by `evaluate()` and `predict()`. 
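The training loop above only advances the global step every gradient_accumulation_steps mini-batches. The same idea in plain PyTorch, as a rough sketch with hypothetical model, optimizer, and loss names (not the ORTTrainer API):

import torch


def train_one_epoch(model, optimizer, dataloader, loss_fn, accumulation_steps=8):
    # Accumulate gradients over several mini-batches before each optimizer step.
    model.train()
    optimizer.zero_grad()
    for step, (data, target) in enumerate(dataloader):
        loss = loss_fn(model(data), target) / accumulation_steps
        loss.backward()
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()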
- - Works both with or without labels. - """ - - logger.info("***** Running %s *****", description) - logger.info(" Num examples = %d", len(dataloader.dataset)) - logger.info(" Batch size = %d", dataloader.batch_size) - eval_losses: List[float] = [] - preds: np.ndarray = None - label_ids: np.ndarray = None - - for inputs in tqdm(dataloader, desc=description): - has_labels = any(inputs.get(k) is not None for k in ["labels", "masked_lm_labels"]) - - for k, v in inputs.items(): - inputs[k] = v.to(self.args.device) - - with torch.no_grad(): - outputs = self.model.eval_step(**inputs) - - if has_labels: - step_eval_loss, logits = outputs[:2] - eval_losses += [step_eval_loss.mean().item()] - else: - logits = outputs[0] - - if preds is None: - preds = logits.detach().cpu().numpy() - else: - preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) - if inputs.get("labels") is not None: - if label_ids is None: - label_ids = inputs["labels"].detach().cpu().numpy() - else: - label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) - - if self.compute_metrics is not None and preds is not None and label_ids is not None: - metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) - else: - metrics = {} - if len(eval_losses) > 0: - metrics["loss"] = np.mean(eval_losses) - - return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) diff --git a/orttraining/orttraining/test/python/utils_multiple_choice.py b/orttraining/orttraining/test/python/utils_multiple_choice.py deleted file mode 100644 index e0febaf2d6334..0000000000000 --- a/orttraining/orttraining/test/python/utils_multiple_choice.py +++ /dev/null @@ -1,269 +0,0 @@ -# adapted from run_multiple_choice.py of huggingface transformers -# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/utils_multiple_choice.py - -import csv -import glob # noqa: F401 -import json # noqa: F401 -import logging -import os -from dataclasses import dataclass -from enum import Enum -from typing import List, Optional - -import torch -import tqdm -from filelock import FileLock -from torch.utils.data.dataset import Dataset -from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available # noqa: F401 - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class InputExample: - """ - A single training/test example for multiple choice - - Args: - example_id: Unique id for the example. - question: string. The untokenized text of the second sequence (question). - contexts: list of str. The untokenized text of the first sequence (context of corresponding question). - endings: list of str. multiple choice's options. Its length must be equal to contexts' length. - label: (Optional) string. The label of the example. This should be - specified for train and dev examples, but not for test examples. - """ - - example_id: str - question: str - contexts: List[str] - endings: List[str] - label: Optional[str] - - -@dataclass(frozen=True) -class InputFeatures: - """ - A single set of features of data. - Property names are the same names as the corresponding inputs to a model. 
- """ - - example_id: str - input_ids: List[List[int]] - attention_mask: Optional[List[List[int]]] - token_type_ids: Optional[List[List[int]]] - label: Optional[int] - - -class Split(Enum): - train = "train" - dev = "dev" - test = "test" - - -class DataProcessor: - """Base class for data converters for multiple choice data sets.""" - - def get_train_examples(self, data_dir): - """Gets a collection of `InputExample`s for the train set.""" - raise NotImplementedError() - - def get_dev_examples(self, data_dir): - """Gets a collection of `InputExample`s for the dev set.""" - raise NotImplementedError() - - def get_test_examples(self, data_dir): - """Gets a collection of `InputExample`s for the test set.""" - raise NotImplementedError() - - def get_labels(self): - """Gets the list of labels for this data set.""" - raise NotImplementedError() - - -class MultipleChoiceDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach - soon. - """ - - features: List[InputFeatures] - - def __init__( - self, - data_dir: str, - tokenizer: PreTrainedTokenizer, - task: str, - processor: DataProcessor, - max_seq_length: Optional[int] = None, - overwrite_cache=False, - mode: Split = Split.train, - ): - cached_features_file = os.path.join( - data_dir, - "cached_{}_{}_{}_{}".format( - mode.value, - tokenizer.__class__.__name__, - str(max_seq_length), - task, - ), - ) - - # Make sure only the first process in distributed training processes the dataset, - # and the others will use the cache. - lock_path = cached_features_file + ".lock" - with FileLock(lock_path): - if os.path.exists(cached_features_file) and not overwrite_cache: - logger.info(f"Loading features from cached file {cached_features_file}") - self.features = torch.load(cached_features_file) - else: - logger.info(f"Creating features from dataset file at {data_dir}") - label_list = processor.get_labels() - if mode == Split.dev: - examples = processor.get_dev_examples(data_dir) - elif mode == Split.test: - examples = processor.get_test_examples(data_dir) - else: - examples = processor.get_train_examples(data_dir) - logger.info("Training examples: %s", len(examples)) - # TODO clean up all this to leverage built-in features of tokenizers - self.features = convert_examples_to_features( - examples, - label_list, - max_seq_length, - tokenizer, - pad_on_left=bool(tokenizer.padding_side == "left"), - pad_token=tokenizer.pad_token_id, - pad_token_segment_id=tokenizer.pad_token_type_id, - ) - logger.info("Saving features into cached file %s", cached_features_file) - torch.save(self.features, cached_features_file) - - def __len__(self): - return len(self.features) - - def __getitem__(self, i) -> InputFeatures: - return self.features[i] - - -class SwagProcessor(DataProcessor): - """Processor for the SWAG data set.""" - - def get_train_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} train") - return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train") - - def get_dev_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} dev") - return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev") - - def get_test_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} dev") - raise ValueError( - "For swag testing, the input file does not contain a label column. It can not be tested in current code" - "setting!" 
- ) - return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test") - - def get_labels(self): - """See base class.""" - return ["0", "1", "2", "3"] - - def _read_csv(self, input_file): - with open(input_file, encoding="utf-8") as f: - return list(csv.reader(f)) - - def _create_examples(self, lines: List[List[str]], type: str): - """Creates examples for the training and dev sets.""" - if type == "train" and lines[0][-1] != "label": - raise ValueError("For training, the input file must contain a label column.") - - examples = [ - InputExample( - example_id=line[2], - question=line[5], # in the swag dataset, the - # common beginning of each - # choice is stored in "sent2". - contexts=[line[4], line[4], line[4], line[4]], - endings=[line[7], line[8], line[9], line[10]], - label=line[11], - ) - for line in lines[1:] # we skip the line with the column names - ] - - return examples - - -def convert_examples_to_features( - examples: List[InputExample], - label_list: List[str], - max_length: int, - tokenizer: PreTrainedTokenizer, - pad_token_segment_id=0, - pad_on_left=False, - pad_token=0, - mask_padding_with_zero=True, -) -> List[InputFeatures]: - """ - Loads a data file into a list of `InputFeatures` - """ - - label_map = {label: i for i, label in enumerate(label_list)} - - features = [] - for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"): - if ex_index % 10000 == 0: - logger.info("Writing example %d of %d" % (ex_index, len(examples))) - choices_inputs = [] - for _ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)): - text_a = context - if example.question.find("_") != -1: - # this is for cloze question - text_b = example.question.replace("_", ending) - else: - text_b = example.question + " " + ending - - inputs = tokenizer.encode_plus( - text_a, - text_b, - add_special_tokens=True, - max_length=max_length, - pad_to_max_length=True, - return_overflowing_tokens=True, - ) - if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0: - logger.info( - "Attention! you are cropping tokens (swag task is ok). " - "If you are training ARC and RACE and you are poping question + options," - "you need to try to use a bigger max seq length!" - ) - - choices_inputs.append(inputs) - - label = label_map[example.label] - - input_ids = [x["input_ids"] for x in choices_inputs] - attention_mask = ( - [x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None - ) - token_type_ids = ( - [x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None - ) - - features.append( - InputFeatures( - example_id=example.example_id, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - label=label, - ) - ) - - for f in features[:2]: - logger.info("*** Example ***") - logger.info("feature: %s" % f) - - return features diff --git a/orttraining/pytorch_frontend_examples/mnist_training.py b/orttraining/pytorch_frontend_examples/mnist_training.py deleted file mode 100644 index dc9b3f654400c..0000000000000 --- a/orttraining/pytorch_frontend_examples/mnist_training.py +++ /dev/null @@ -1,200 +0,0 @@ -## This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py -## with modification to do training using onnxruntime as backend on cuda device. -## A private PyTorch build from https://aiinfra.visualstudio.com/Lotus/_git/pytorch (ORTTraining branch) is needed to run the demo. 
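The multiple-choice preprocessing removed above pairs each context with each candidate ending, substituting the ending into cloze-style questions (those containing "_") and appending it otherwise, before calling tokenizer.encode_plus. A minimal sketch of that pairing step, with illustrative names:

def build_choice_texts(question, contexts, endings):
    # Pair each context with its candidate ending; cloze-style questions use
    # substitution, otherwise the ending is appended to the question.
    pairs = []
    for context, ending in zip(contexts, endings):
        if "_" in question:
            pairs.append((context, question.replace("_", ending)))
        else:
            pairs.append((context, question + " " + ending))
    return pairs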
- -## Model testing is not complete. - -import argparse -import os - -import numpy as np # noqa: F401 -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim # noqa: F401 -from mpi4py import MPI -from torchvision import datasets, transforms - -from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer - -try: # noqa: SIM105 - from onnxruntime.capi._pybind_state import set_cuda_device_id -except ImportError: - pass - - -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return out - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -def train_with_trainer(args, trainer, device, train_loader, epoch): - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - learning_rate = torch.tensor([args.lr]) - loss = trainer.train_step(data, target, learning_rate) - - # Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss. - if batch_idx % args.log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss[0], - ) - ) - - -# TODO: comple this once ORT training can do evaluation. -def test_with_trainer(args, trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = F.log_softmax(trainer.eval_step(data, fetches=["probability"]), dim=1) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def mnist_model_description(): - input_desc = IODescription("input1", ["batch", 784], torch.float32) - label_desc = IODescription( - "label", - [ - "batch", - ], - torch.int64, - num_classes=10, - ) - loss_desc = IODescription("loss", [], torch.float32) - probability_desc = IODescription("probability", ["batch", 10], torch.float32) - return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc]) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", 
action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - - args = parser.parse_args() - use_cuda = not args.no_cuda and torch.cuda.is_available() - - torch.manual_seed(args.seed) - - kwargs = {"num_workers": 0, "pin_memory": True} - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "../data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - **kwargs, - ) - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "../data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - **kwargs, - ) - - comm = MPI.COMM_WORLD - args.local_rank = ( - int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) if ("OMPI_COMM_WORLD_LOCAL_RANK" in os.environ) else 0 - ) - args.world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) if ("OMPI_COMM_WORLD_RANK" in os.environ) else 0 - args.world_size = comm.Get_size() - if use_cuda: - torch.cuda.set_device(args.local_rank) - device = torch.device("cuda", args.local_rank) - args.n_gpu = 1 - set_cuda_device_id(args.local_rank) - else: - device = torch.device("cpu") - - input_size = 784 - hidden_size = 500 - num_classes = 10 - model = NeuralNet(input_size, hidden_size, num_classes) - - model_desc = mnist_model_description() - # use log_interval as gradient accumulate steps - trainer = ORTTrainer( - model, - my_loss, - model_desc, - "SGDOptimizer", - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - 1, - args.world_rank, - args.world_size, - use_mixed_precision=False, - allreduce_post_accumulation=True, - ) - print("\nBuild ort model done.") - - for epoch in range(1, args.epochs + 1): - train_with_trainer(args, trainer, device, train_loader, epoch) - test_with_trainer(args, trainer, device, test_loader) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/mnist/mnist_original.onnx b/samples/python/training/orttrainer/mnist/mnist_original.onnx deleted file mode 100644 index 15931affb5ccf..0000000000000 Binary files a/samples/python/training/orttrainer/mnist/mnist_original.onnx and /dev/null differ diff --git a/samples/python/training/orttrainer/mnist/ort_mnist.py b/samples/python/training/orttrainer/mnist/ort_mnist.py deleted file mode 100644 index 8f8ccf373ccf6..0000000000000 --- a/samples/python/training/orttrainer/mnist/ort_mnist.py +++ /dev/null @@ -1,174 +0,0 @@ -# This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py -# with modification to do training using onnxruntime as backend on cuda device. 
- -import argparse -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision import datasets, transforms - -import onnxruntime -from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim - - -# Pytorch model -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, input1): - out = self.fc1(input1) - out = self.relu(out) - out = self.fc2(out) - return out - - -# ONNX Runtime training -def mnist_model_description(): - return { - "inputs": [("input1", ["batch", 784]), ("label", ["batch"])], - "outputs": [("loss", [], True), ("probability", ["batch", 10])], - } - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -# Helpers -def train(log_interval, trainer, device, train_loader, epoch, train_steps): - for batch_idx, (data, target) in enumerate(train_loader): - if batch_idx == train_steps: - break - - # Fetch data - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - # Train step - loss, prob = trainer.train_step(data, target) - - # Stats - if batch_idx % log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, batch_idx * len(data), len(train_loader.dataset), 100.0 * batch_idx / len(train_loader), loss - ) - ) - - -def test(trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - # Using fetches around without eval_step to not pass 'target' as input - trainer._train_step_info.fetches = ["probability"] - output = F.log_softmax(trainer.eval_step(data), dim=1) - trainer._train_step_info.fetches = [] - - # Stats - test_loss += F.nll_loss(output, target, reduction="sum").item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="ONNX Runtime MNIST Example") - parser.add_argument( - "--train-steps", - type=int, - default=-1, - metavar="N", - help="number of steps to train. 
Set -1 to run through whole dataset (default: -1)", - ) - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model state") - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - onnxruntime.set_seed(args.seed) - - # Data loader - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - ) - - if args.test_batch_size > 0: - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - ) - - # Modeling - model = NeuralNet(784, 500, 10) - model_desc = mnist_model_description() - optim_config = optim.SGDConfig(lr=args.lr) - opts = {"device": {"id": device}} - opts = ORTTrainerOptions(opts) - - trainer = ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - - # Train loop - for epoch in range(1, args.epochs + 1): - train(args.log_interval, trainer, device, train_loader, epoch, args.train_steps) - if args.test_batch_size > 0: - test(trainer, device, test_loader) - - # Save model - if args.save_path: - torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt")) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/mnist/pytorch_mnist.py b/samples/python/training/orttrainer/mnist/pytorch_mnist.py deleted file mode 100644 index 2e451d85f62e8..0000000000000 --- a/samples/python/training/orttrainer/mnist/pytorch_mnist.py +++ /dev/null @@ -1,157 +0,0 @@ -import argparse -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms - - -# Pytorch model -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, input1): - out = self.fc1(input1) - out = self.relu(out) - out = self.fc2(out) - return out - - -def my_loss(x, target, is_train=True): - if is_train: - return F.nll_loss(F.log_softmax(x, dim=1), target) - else: - return F.nll_loss(F.log_softmax(x, dim=1), target, reduction="sum") - - -# Helpers -def train(args, model, device, train_loader, optimizer, epoch): - 
model.train() - for batch_idx, (data, target) in enumerate(train_loader): - if batch_idx == args.train_steps: - break - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - optimizer.zero_grad() - output = model(data) - loss = my_loss(output, target) - loss.backward() - optimizer.step() - if batch_idx % args.log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - - -def test(model, device, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = model(data) - # Stats - test_loss += my_loss(output, target, False).item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--train-steps", - type=int, - default=-1, - metavar="N", - help="number of steps to train. Set -1 to run through whole dataset (default: -1)", - ) - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model") - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - - # Data loader - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - ) - - if args.test_batch_size > 0: - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - ) - - # Modeling - model = NeuralNet(784, 500, 10).to(device) - optimizer = optim.SGD(model.parameters(), lr=args.lr) - - # Train loop - for epoch in range(1, args.epochs + 1): - train(args, model, device, train_loader, optimizer, epoch) - if args.test_batch_size > 0: - test(model, device, test_loader) 
- - # Save model - if args.save_path: - torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt")) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/pytorch_transformer/README.md b/samples/python/training/orttrainer/pytorch_transformer/README.md deleted file mode 100644 index cda8cba6ca0ad..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# TransformerModel example - -This example was adapted from Pytorch's [Sequence-to-Sequence Modeling with nn.Transformer and TorchText](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) tutorial - -## Requirements - -* PyTorch 1.6+ -* TorchText 0.6+ -* ONNX Runtime 1.5+ - -## Running PyTorch version - -```bash -python pt_train.py -``` - -## Running ONNX Runtime version - -```bash -python ort_train.py -``` - -## Optional arguments - -| Argument | Description | Default | -| :---------------- | :-----------------------------------------------------: | --------: | -| --batch-size | input batch size for training | 20 | -| --test-batch-size | input batch size for testing | 20 | -| --epochs | number of epochs to train | 2 | -| --lr | learning rate | 0.001 | -| --no-cuda | disables CUDA training | False | -| --seed | random seed | 1 | -| --log-interval | how many batches to wait before logging training status | 200 | diff --git a/samples/python/training/orttrainer/pytorch_transformer/ort_train.py b/samples/python/training/orttrainer/pytorch_transformer/ort_train.py deleted file mode 100644 index 551e878cc9035..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/ort_train.py +++ /dev/null @@ -1,89 +0,0 @@ -import argparse - -import torch -from ort_utils import my_loss, transformer_model_description_dynamic_axes -from pt_model import TransformerModel -from utils import get_batch, prepare_data - -import onnxruntime - - -def train(trainer, data_source, device, epoch, args, bptt=35): - total_loss = 0.0 - for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): - data, targets = get_batch(data_source, i) - - loss, pred = trainer.train_step(data, targets) - total_loss += loss.item() - if batch % args.log_interval == 0 and batch > 0: - cur_loss = total_loss / args.log_interval - print( - "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format( - epoch, batch, len(data_source) // bptt, cur_loss - ) - ) - total_loss = 0 - - -def evaluate(trainer, data_source, bptt=35): - total_loss = 0.0 - with torch.no_grad(): - for i in range(0, data_source.size(0) - 1, bptt): - data, targets = get_batch(data_source, i) - loss, pred = trainer.eval_step(data, targets) - total_loss += len(data) * loss.item() - return total_loss / (len(data_source) - 1) - - -if __name__ == "__main__": - # Training settings - parser = argparse.ArgumentParser(description="PyTorch TransformerModel example") - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)" - ) - parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)") - parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, 
default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=200, - metavar="N", - help="how many batches to wait before logging training status (default: 200)", - ) - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - onnxruntime.set_seed(args.seed) - - # Model - optim_config = onnxruntime.training.optim.SGDConfig(lr=args.lr) - model_desc = transformer_model_description_dynamic_axes() - model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - - # Preparing data - train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size) - trainer = onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss) - - # Train - for epoch in range(1, args.epochs + 1): - train(trainer, train_data, device, epoch, args) - val_loss = evaluate(trainer, val_data) - print("-" * 89) - print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ") - print("-" * 89) - - # Evaluate - test_loss = evaluate(trainer, test_data) - print("=" * 89) - print(f"| End of training | test loss {test_loss:5.2f}") - print("=" * 89) diff --git a/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py b/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py deleted file mode 100644 index 73992f5596f5f..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription -from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription - - -def my_loss(x, target): - x = x.view(-1, 28785) - return torch.nn.CrossEntropyLoss()(x, target) - - -def transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - model_desc = { - "inputs": [("input1", [bptt, batch_size]), ("label", [bptt * batch_size])], - "outputs": [("loss", [], True), ("predictions", [bptt, batch_size, ntokens])], - } - return model_desc - - -def transformer_model_description_dynamic_axes(ntokens=28785): - model_desc = { - "inputs": [("input1", ["bptt", "batch_size"]), ("label", ["bptt_x_batch_size"])], - "outputs": [("loss", [], True), ("predictions", ["bptt", "batch_size", ntokens])], - } - return model_desc - - -def legacy_transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - input_desc = Legacy_IODescription("input1", [bptt, batch_size]) - label_desc = Legacy_IODescription("label", [bptt * batch_size]) - loss_desc = Legacy_IODescription("loss", []) - predictions_desc = Legacy_IODescription("predictions", [bptt, batch_size, ntokens]) - return ( - Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), - Legacy_IODescription("__learning_rate", [1]), - ) - - -def legacy_transformer_model_description_dynamic_axes(ntokens=28785): - input_desc = Legacy_IODescription("input1", ["bptt", "batch_size"]) - label_desc = Legacy_IODescription("label", ["bptt_x_batch_size"]) - loss_desc = Legacy_IODescription("loss", []) - predictions_desc = Legacy_IODescription("predictions", ["bptt", "batch_size", ntokens]) - return ( - Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), - Legacy_IODescription("__learning_rate", [1]), - ) diff --git a/samples/python/training/orttrainer/pytorch_transformer/pt_model.py b/samples/python/training/orttrainer/pytorch_transformer/pt_model.py 
deleted file mode 100644 index 4f2e03192c6cf..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/pt_model.py +++ /dev/null @@ -1,62 +0,0 @@ -import math - -import torch -import torch.nn as nn - - -class TransformerModel(nn.Module): - def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): - super().__init__() - from torch.nn import TransformerEncoder, TransformerEncoderLayer - - self.model_type = "Transformer" - self.input1_mask = None - self.pos_encoder = PositionalEncoding(ninp, dropout) - encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) - self.encoder = nn.Embedding(ntoken, ninp) - self.ninp = ninp - self.decoder = nn.Linear(ninp, ntoken) - - self.init_weights() - - def _generate_square_subsequent_mask(self, sz): - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0) - return mask - - def init_weights(self): - initrange = 0.1 - self.encoder.weight.data.uniform_(-initrange, initrange) - self.decoder.bias.data.zero_() - self.decoder.weight.data.uniform_(-initrange, initrange) - - def forward(self, input1): - if self.input1_mask is None or self.input1_mask.size(0) != input1.size(0): - device = input1.device - mask = self._generate_square_subsequent_mask(input1.size(0)).to(device) - self.input1_mask = mask - - input1 = self.encoder(input1) * math.sqrt(self.ninp) - input1 = self.pos_encoder(input1) - output = self.transformer_encoder(input1, self.input1_mask) - output = self.decoder(output) - return output - - -class PositionalEncoding(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer("pe", pe) - - def forward(self, x): - x = x + self.pe[: x.size(0), :] - return self.dropout(x) diff --git a/samples/python/training/orttrainer/pytorch_transformer/pt_train.py b/samples/python/training/orttrainer/pytorch_transformer/pt_train.py deleted file mode 100644 index a197fb50357e9..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/pt_train.py +++ /dev/null @@ -1,94 +0,0 @@ -import argparse - -import torch -import torch.nn as nn -from pt_model import TransformerModel -from utils import get_batch, prepare_data - - -def train(model, data_source, device, epoch, args, bptt=35): - total_loss = 0.0 - model.train() - for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): - data, targets = get_batch(data_source, i) - - optimizer.zero_grad() - output = model(data) - loss = criterion(output.view(-1, 28785), targets) - loss.backward() - optimizer.step() - - total_loss += loss.item() - if batch % args.log_interval == 0 and batch > 0: - cur_loss = total_loss / args.log_interval - print( - "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format( - epoch, batch, len(data_source) // bptt, cur_loss - ) - ) - total_loss = 0 - - -def evaluate(model, data_source, criterion, bptt=35): - total_loss = 0.0 - model.eval() - with torch.no_grad(): - for i in range(0, data_source.size(0) - 1, bptt): - data, targets = get_batch(data_source, i) 
- output = model(data) - output_flat = output.view(-1, 28785) - total_loss += len(data) * criterion(output_flat, targets).item() - return total_loss / (len(data_source) - 1) - - -if __name__ == "__main__": - # Training settings - parser = argparse.ArgumentParser(description="PyTorch TransformerModel example") - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)" - ) - parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)") - parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=200, - metavar="N", - help="how many batches to wait before logging training status (default: 200)", - ) - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - - # Model - criterion = nn.CrossEntropyLoss() - lr = 0.001 - model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - optimizer = torch.optim.SGD(model.parameters(), lr=lr) - - # Preparing data - train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size) - - # Train - for epoch in range(1, args.epochs + 1): - train(model, train_data, device, epoch, args) - val_loss = evaluate(model, val_data, criterion) - print("-" * 89) - print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ") - print("-" * 89) - - # Evaluate - test_loss = evaluate(model, test_data, criterion) - print("=" * 89) - print(f"| End of training | test loss {test_loss:5.2f}") - print("=" * 89) diff --git a/samples/python/training/orttrainer/pytorch_transformer/utils.py b/samples/python/training/orttrainer/pytorch_transformer/utils.py deleted file mode 100644 index 3be8b6cf3f420..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/utils.py +++ /dev/null @@ -1,59 +0,0 @@ -import os - -import torch -from torchtext.data.utils import get_tokenizer -from torchtext.utils import download_from_url, extract_archive -from torchtext.vocab import build_vocab_from_iterator - - -def batchify(data, bsz, device): - # Divide the dataset into bsz parts. - nbatch = data.size(0) // bsz - # Trim off any extra elements that wouldn't cleanly fit (remainders). - data = data.narrow(0, 0, nbatch * bsz) - # Evenly divide the data across the bsz batches. 
- data = data.view(bsz, -1).t().contiguous() - return data.to(device) - - -def get_batch(source, i, bptt=35): - seq_len = min(bptt, len(source) - 1 - i) - data = source[i : i + seq_len] - target = source[i + 1 : i + 1 + seq_len].view(-1) - return data, target - - -def prepare_data(device="cpu", train_batch_size=20, eval_batch_size=20, data_dir=None): - url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip" - - download_path = ".data_wikitext_2_v1" - extract_path = None - if data_dir: - download_path = os.path.join(data_dir, "download") - os.makedirs(download_path, exist_ok=True) - download_path = os.path.join(download_path, "wikitext-2-v1.zip") - - extract_path = os.path.join(data_dir, "extracted") - os.makedirs(extract_path, exist_ok=True) - - test_filepath, valid_filepath, train_filepath = extract_archive( - download_from_url(url, root=download_path), to_path=extract_path - ) - tokenizer = get_tokenizer("basic_english") - vocab = build_vocab_from_iterator(map(tokenizer, iter(open(train_filepath, encoding="utf8")))) # noqa: SIM115 - - def data_process(raw_text_iter): - data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter] - return torch.cat(tuple(filter(lambda t: t.numel() > 0, data))) - - train_data = data_process(iter(open(train_filepath, encoding="utf8"))) # noqa: SIM115 - val_data = data_process(iter(open(valid_filepath, encoding="utf8"))) # noqa: SIM115 - test_data = data_process(iter(open(test_filepath, encoding="utf8"))) # noqa: SIM115 - - device = torch.device(device) - - train_data = batchify(train_data, train_batch_size, device) - val_data = batchify(val_data, eval_batch_size, device) - test_data = batchify(test_data, eval_batch_size, device) - - return train_data, val_data, test_data diff --git a/setup.py b/setup.py index 4df48239c8cbd..608bffc082e07 100644 --- a/setup.py +++ b/setup.py @@ -196,7 +196,7 @@ def run(self): "libcublasLt.so.11", "libcublasLt.so.12", "libcudart.so.11.0", - "libcudart.so.12.0", + "libcudart.so.12", "libcudnn.so.8", "libcufft.so.10", "libcufft.so.11", @@ -398,7 +398,6 @@ def finalize_options(self): "onnxruntime", "onnxruntime.backend", "onnxruntime.capi", - "onnxruntime.capi.training", "onnxruntime.datasets", "onnxruntime.tools", "onnxruntime.tools.mobile_helpers", diff --git a/tools/android_custom_build/Dockerfile b/tools/android_custom_build/Dockerfile index 66b6a36e5a8c0..754a6633b0c62 100644 --- a/tools/android_custom_build/Dockerfile +++ b/tools/android_custom_build/Dockerfile @@ -55,7 +55,7 @@ WORKDIR /workspace # install Android SDK and tools ENV ANDROID_HOME=~/android-sdk -ENV NDK_VERSION=26.0.10792818 +ENV NDK_VERSION=26.1.10909125 ENV ANDROID_NDK_HOME=${ANDROID_HOME}/ndk/${NDK_VERSION} RUN aria2c -q -d /tmp -o cmdline-tools.zip \ diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index e0559419ef8c7..6bd3e2533c045 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1171,9 +1171,9 @@ def generate_build_tree( "-Donnxruntime_USE_OPENVINO_AUTO=" + ("ON" if args.use_openvino.startswith("AUTO") else "OFF"), ] - # TensorRT and OpenVINO providers currently only support + # VitisAI and OpenVINO providers currently only support # full_protobuf option. 
- if args.use_full_protobuf or args.use_tensorrt or args.use_openvino or args.use_vitisai or args.gen_doc: + if args.use_full_protobuf or args.use_openvino or args.use_vitisai or args.gen_doc: cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"] if args.use_tvm and args.llvm_path is not None: diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml new file mode 100644 index 0000000000000..aee42d3675087 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/py-cuda-packaging-pipeline.yml @@ -0,0 +1,39 @@ +trigger: none + +parameters: + - name: enable_linux_gpu + type: boolean + default: true + - name: enable_windows_gpu + type: boolean + default: true + - name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + - name: cuda_version + type: string + default: '12.2' + values: + - 11.8 + - 12.2 + +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + +stages: + - template: stages/py-cuda-packaging-stage.yml + parameters: + enable_linux_gpu: ${{ parameters.enable_linux_gpu }} + enable_windows_gpu: ${{ parameters.enable_windows_gpu }} + cmake_build_type: ${{ parameters.cmake_build_type }} + cuda_version: ${{ parameters.cuda_version }} \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml new file mode 100644 index 0000000000000..f3d68957d649c --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml @@ -0,0 +1,105 @@ +parameters: +- name: build_py_parameters + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. + type: string + default: '' + +- name: enable_linux_gpu + displayName: 'Whether Linux GPU package is built.' + type: boolean + default: true + +- name: enable_windows_gpu + displayName: 'Whether Windows GPU package is built.' + type: boolean + default: true + +# TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. +- name: cmake_build_type + type: string + displayName: 'Linux packages cmake build type. Linux Only.' + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: cuda_version + type: string + displayName: 'CUDA version. Windows Only.' 
+ default: '12.2' + values: + - 11.8 + - 12.2 + +stages: +- stage: Python_Packaging + dependsOn: [] + variables: + - name: docker_base_image + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 + - name: linux_trt_version + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: 8.6.1.6-1.cuda11.8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: 8.6.1.6-1.cuda12.0 + - name: win_trt_home + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: $(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0 + - name: win_cuda_home + ${{ if eq(parameters.cuda_version, '11.8') }}: + value: $(Agent.TempDirectory)\v11.8 + ${{ if eq(parameters.cuda_version, '12.2') }}: + value: $(Agent.TempDirectory)\v12.2 + jobs: + - ${{ if eq(parameters.enable_windows_gpu, true) }}: + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.8' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.9' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.10' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + - template: ../templates/py-win-gpu.yml + parameters: + MACHINE_POOL: 'onnxruntime-Win2022-GPU-T4' + PYTHON_VERSION: '3.11' + EP_BUILD_FLAGS: --use_tensorrt --tensorrt_home=${{ variables.win_trt_home }} --cuda_home=${{ variables.win_cuda_home }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + EP_NAME: gpu + CudaVersion: ${{ parameters.cuda_version }} + + + - ${{ if eq(parameters.enable_linux_gpu, true) }}: + - template: ../templates/py-linux-gpu.yml + parameters: + arch: 'x86_64' + machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + docker_base_image: ${{ variables.docker_base_image }} + trt_version: ${{ variables.linux_trt_version }} + cuda_version: ${{ parameters.cuda_version }} diff --git a/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml b/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml index 56f6bd56eeed7..e664cf69dec76 100644 --- a/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml +++ b/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml @@ -67,9 +67,9 @@ steps: EM_DIR: '$(Build.SourcesDirectory)/cmake/external/emsdk/upstream/emscripten' - ${{if eq(parameters.WithCache, false)}}: - - task: PythonScript@0 - displayName: 
'${{parameters.DisplayName}}' - inputs: - scriptPath: '$(Build.SourcesDirectory)/tools/ci_build/build.py' - arguments: ${{parameters.Arguments}} - workingDirectory: '$(Build.BinariesDirectory)' + - script: | + set -e -x + source $(Build.SourcesDirectory)/cmake/external/emsdk/emsdk_env.sh + cd '$(Build.BinariesDirectory)' + python3 '$(Build.SourcesDirectory)/tools/ci_build/build.py' ${{parameters.Arguments}} + displayName: ${{parameters.DisplayName}} diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index 4573c56963e34..ff7f0957e94ba 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -34,7 +34,7 @@ steps: displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8' - powershell: | Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8\lib" - displayName: 'Append CUDA SDK Directory to PATH' + displayName: 'Append TensorRT Directory to PATH' - ${{ if eq(parameters.CudaVersion, '12.2') }}: - powershell: | @@ -42,7 +42,7 @@ steps: displayName: 'Download TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0' - powershell: | Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-8.6.1.6.Windows10.x86_64.cuda-12.0\lib" - displayName: 'Append CUDA SDK Directory to PATH' + displayName: 'Append TensorRT Directory to PATH' - task: CmdLine@2 inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 9282cfccd02f0..e40c4d0e95dc5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -4,6 +4,7 @@ parameters: - name: EnvSetupScript type: string + default: setup_env.bat - name: job_name_suffix type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index f81b1ddc8b93b..852d688b2dbb1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -90,13 +90,20 @@ jobs: arguments: --new_dir $(Build.BinariesDirectory)/deps workingDirectory: $(Build.BinariesDirectory) - - script: | - set -ex - cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 3.1.44 ccache-git-emscripten-64bit - ./emsdk activate 3.1.44 ccache-git-emscripten-64bit - displayName: 'emsdk install and activate ccache for emscripten' - condition: eq('${{ parameters.WithCache }}', 'true') + - ${{if eq(parameters.WithCache, true)}}: + - script: | + set -ex + cd '$(Build.SourcesDirectory)/cmake/external/emsdk' + ./emsdk install 3.1.44 ccache-git-emscripten-64bit + ./emsdk activate 3.1.44 ccache-git-emscripten-64bit + displayName: 'emsdk install and activate ccache for emscripten' + - ${{if eq(parameters.WithCache, false)}}: + - script: | + set -ex + cd '$(Build.SourcesDirectory)/cmake/external/emsdk' + ./emsdk install 3.1.44 + ./emsdk activate 3.1.44 + displayName: 'emsdk install and activate ccache for emscripten' - template: build-linux-wasm-step.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml 
b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml index f68847afff379..8cc48aac7a3b9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml @@ -17,7 +17,24 @@ parameters: - Release - RelWithDebInfo - MinSizeRel - +- name: docker_base_image + type: string + default: 'nvidia/cuda:11.8.0-cudnn8-devel-ubi8' + values: + - nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + - nvidia/cuda:12.2.2-cudnn8-devel-ubi8 +- name: trt_version + type: string + default: '8.6.1.6-1.cuda11.8' + values: + - 8.6.1.6-1.cuda11.8 + - 8.6.1.6-1.cuda12.0 +- name: cuda_version + type: string + default: '11.8' + values: + - 11.8 + - 12.2 jobs: - job: Linux_py_GPU_Wheels_${{ parameters.arch }} timeoutInMinutes: 240 @@ -26,7 +43,13 @@ jobs: pool: ${{ parameters.machine_pool }} variables: # The build machine pool doesn't have dotnet, so it can't run CG. - skipComponentGovernanceDetection: true + - name: skipComponentGovernanceDetection + value: true + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: -x ${{ parameters.extra_build_arg }} + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' steps: - checkout: self clean: true @@ -40,12 +63,12 @@ jobs: Context: tools/ci_build/github/linux/docker DockerBuildArgs: " --network=host - --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 - --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg BASEIMAGE=${{ parameters.docker_base_image }} + --build-arg TRT_VERSION=${{ parameters.trt_version }} --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }} " - Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} + Repository: onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} - task: Bash@3 @@ -53,8 +76,7 @@ jobs: inputs: targetType: filePath filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - # please check ONNXRUNTIME_CUDA_VERSION in tools/ci_build/github/linux/build_linux_arm64_python_package.sh - arguments: -i onnxruntimecuda118xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} -x "${{ parameters.extra_build_arg }}" + arguments: -i onnxruntimecuda${{ replace(parameters.cuda_version, '.', '') }}xtrt86build${{ parameters.arch }} -d "GPU" -c ${{ parameters.cmake_build_type }} $(extra_build_args) - task: PublishBuildArtifacts@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml index 0774c3350b9b1..db3782c69cf62 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux.yml @@ -46,9 +46,17 @@ jobs: pool: ${{ parameters.machine_pool }} variables: # The build machine pool doesn't have dotnet, so it can't run CG. 
- skipComponentGovernanceDetection: true - ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: skipComponentGovernanceDetection + value: true + - name: ORT_CACHE_DIR + value: $(Agent.TempDirectory)/ort_ccache + - name: TODAY + value: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + - name: extra_build_args + ${{ if ne(parameters.extra_build_arg, '') }}: + value: -x ${{ parameters.extra_build_arg }} + ${{ if eq(parameters.extra_build_arg, '') }}: + value: '' steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -82,7 +90,7 @@ jobs: inputs: targetType: filePath filePath: tools/ci_build/github/linux/run_python_dockerbuild.sh - arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} -x "${{ parameters.extra_build_arg }}" + arguments: -i onnxruntimecpubuildpython${{ parameters.arch }} -d "${{ parameters.device }}" -c ${{ parameters.cmake_build_type }} $(extra_build_args) ${{ if eq(parameters.with_cache, 'true') }}: env: ADDITIONAL_DOCKER_PARAMETER: "--volume $(ORT_CACHE_DIR):/cache -e CCACHE_DIR=/cache -e ORT_BUILD_WITH_CACHE=1" diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index 919749cac15b6..501251eaff20f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -14,21 +14,32 @@ parameters: - name: ENV_SETUP_SCRIPT type: string + default: '' - name: BUILD_PY_PARAMETERS displayName: > Extra parameters to pass to build.py. Don't put newlines in here. type: string default: '' - +- name: CudaVersion + type: string + default: '11.8' + values: + - 11.8 + - 12.2 jobs: - job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 240 workspace: clean: all - pool: ${{ parameters.MACHINE_POOL }} + pool: + name: ${{ parameters.MACHINE_POOL }} +# demands: +# - ImageVersionOverride -equals 1.0.367516 variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' + CUDA_MODULE_LOADING: 'LAZY' steps: - checkout: self clean: true @@ -61,10 +72,21 @@ jobs: - template: download-deps.yml - - template: jobs/set-winenv.yml - parameters: - EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} - DownloadCUDA: true + - ${{ if ne(parameters.ENV_SETUP_SCRIPT, '') }}: + - template: jobs/set-winenv.yml + parameters: + EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} + ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: + DownloadCUDA: true + + - ${{ if eq(parameters.ENV_SETUP_SCRIPT, '') }}: + - template: jobs/download_win_gpu_library.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + ${{ if or(contains(parameters.EP_BUILD_FLAGS, 'use_cuda'), contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt')) }}: + DownloadCUDA: true + ${{ if contains(parameters.EP_BUILD_FLAGS, 'use_tensorrt') }}: + DownloadTRT: true - task: PythonScript@0 displayName: 'Update deps.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml b/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml index 8cc7f63a193cc..b8dba89b0b899 100644 --- a/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml +++ 
b/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml @@ -3,7 +3,7 @@ parameters: - name: AndroidNdkVersion type: string - default: "26.0.10792818" # LTS version + default: "26.1.10909125" # LTS version steps: - bash: | diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index c649883ea0d8b..9982b36509b68 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -65,7 +65,6 @@ stages: clean: all steps: - checkout: self - fetchDepth: 1 submodules: false - script: | git submodule sync -- cmake/external/onnx diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index ed010b5619db5..d7ffc1828c943 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -40,7 +40,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'Debug' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --build_java --build_nodejs --build_wheel --disable_memleak_checker msbuildPlatform: x64 @@ -59,7 +58,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 # Compare to our Nuget packaging pipeline, this job has "--build_wheel" but doesn't have "--enable_lto --disable_rtti --use_telemetry --enable_wcos" # Python bindings use typeid so I can't disable RTTI here. If it causes a problem, we will need to split this job to two jobs. @@ -80,7 +78,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --build_wheel --use_dnnl --build_java msbuildPlatform: x64 @@ -101,7 +98,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --build_wheel --use_xnnpack msbuildPlatform: x64 @@ -120,7 +116,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --use_winml --enable_wcos --disable_rtti --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.22000.0 msbuildPlatform: x64 @@ -160,7 +155,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'Debug' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --enable_training --build_wheel --disable_memleak_checker msbuildPlatform: x64 @@ -179,7 +173,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --enable_training --build_wheel msbuildPlatform: x64 @@ -198,7 +191,6 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env.bat buildArch: x64 additionalBuildFlags: --enable_training_apis msbuildPlatform: x64 @@ -215,10 +207,17 @@ stages: - stage: x64_release_azure dependsOn: [] jobs: + - job: + steps: + - powershell: | + Write-Host "##vso[task.prependpath]$(Build.BinariesDirectory)\RelWithDebInfo\_deps\vcpkg-src\installed\x86-windows\bin" + $env:PATH + Write-Host 
"##vso[task.prependpath]$(Build.BinariesDirectory)\RelWithDebInfo\_deps\vcpkg-src\installed\x64-windows\bin" + $env:PATH + displayName: 'Append x64-windows and x86-windows to PATH' - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_azure.bat buildArch: x64 additionalBuildFlags: --use_azure --use_lock_free_queue msbuildPlatform: x64 @@ -231,3 +230,5 @@ stages: GenerateDocumentation: false WITH_CACHE: true MachinePool: 'onnxruntime-Win-CPU-2022' + + diff --git a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh similarity index 78% rename from tools/ci_build/github/linux/build_linux_arm64_python_package.sh rename to tools/ci_build/github/linux/build_linux_python_package.sh index 516f320cd64c4..3c1c65c9a6862 100755 --- a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_python_package.sh @@ -15,9 +15,11 @@ do case "${parameter_Option}" in #GPU or CPU. d) BUILD_DEVICE=${OPTARG};; -p) PYTHON_EXES=(${OPTARG});; -x) EXTRA_ARG=(${OPTARG});; +p) PYTHON_EXES=${OPTARG};; +x) EXTRA_ARG=${OPTARG};; c) BUILD_CONFIG=${OPTARG};; +*) echo "Usage: $0 -d [-p ] [-x ] [-c ]" + exit 1;; esac done @@ -48,7 +50,7 @@ if [ "$ARCH" == "x86_64" ] && [ "$GCC_VERSION" -ge 9 ]; then fi echo "EXTRA_ARG:" -echo $EXTRA_ARG +echo "$EXTRA_ARG" if [ "$EXTRA_ARG" != "" ]; then BUILD_ARGS+=("$EXTRA_ARG") @@ -60,19 +62,19 @@ if [ "$ARCH" == "x86_64" ]; then fi if [ "$BUILD_DEVICE" == "GPU" ]; then + SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/') #Enable CUDA and TRT EPs. - ONNXRUNTIME_CUDA_VERSION="11.8" - BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$ONNXRUNTIME_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") + BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") fi export CFLAGS export CXXFLAGS for PYTHON_EXE in "${PYTHON_EXES[@]}" do - rm -rf /build/$BUILD_CONFIG + rm -rf /build/"$BUILD_CONFIG" ${PYTHON_EXE} /onnxruntime_src/tools/ci_build/build.py "${BUILD_ARGS[@]}" - cp /build/$BUILD_CONFIG/dist/*.whl /build/dist + cp /build/"$BUILD_CONFIG"/dist/*.whl /build/dist done which ccache && ccache -sv && ccache -z diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh index 3bca6413100a2..da8a45e00cc90 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_lort.sh @@ -19,7 +19,9 @@ fi export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" +# This may install PyTorch, which will be overrided by the PyTorch local build below. /opt/python/cp39-cp39/bin/python3.9 -m pip install transformers + # beartype is installed here so that onnxscript installation step won't # install a version PyTorch doesn't like. Once beartype fixes this problem. # We can remove this line. 
diff --git a/tools/ci_build/github/linux/run_python_dockerbuild.sh b/tools/ci_build/github/linux/run_python_dockerbuild.sh index 18ac6482827f9..ff2ce6f7ff231 100755 --- a/tools/ci_build/github/linux/run_python_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_python_dockerbuild.sh @@ -9,24 +9,32 @@ i) DOCKER_IMAGE=${OPTARG};; d) DEVICE=${OPTARG};; x) BUILD_EXTR_PAR=${OPTARG};; c) BUILD_CONFIG=${OPTARG};; +*) echo "Usage: $0 -i -d [-x ] [-c ]" + exit 1;; esac done -mkdir -p $HOME/.onnx +mkdir -p "${HOME}/.onnx" +DOCKER_SCRIPT_OPTIONS="-d ${DEVICE} -c ${BUILD_CONFIG}" + +if [ "${BUILD_EXTR_PAR}" != "" ] ; then + DOCKER_SCRIPT_OPTIONS+=" -x ${BUILD_EXTR_PAR}" +fi + docker run --rm \ --volume /data/onnx:/data/onnx:ro \ - --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src \ - --volume $BUILD_BINARIESDIRECTORY:/build \ + --volume "${BUILD_SOURCESDIRECTORY}:/onnxruntime_src" \ + --volume "${BUILD_BINARIESDIRECTORY}:/build" \ --volume /data/models:/build/models:ro \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume "${HOME}/.onnx:/home/onnxruntimedev/.onnx" \ -w /onnxruntime_src \ -e NIGHTLY_BUILD \ -e BUILD_BUILDNUMBER \ $ADDITIONAL_DOCKER_PARAMETER \ - $DOCKER_IMAGE tools/ci_build/github/linux/build_linux_arm64_python_package.sh -d $DEVICE -c $BUILD_CONFIG -x $BUILD_EXTR_PAR + $DOCKER_IMAGE tools/ci_build/github/linux/build_linux_python_package.sh $DOCKER_SCRIPT_OPTIONS -sudo rm -rf $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/onnxruntime $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/pybind11 \ - $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/models $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/_deps \ - $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/CMakeFiles -cd $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG -find -executable -type f > $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/perms.txt +sudo rm -rf "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/onnxruntime" "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/pybind11" \ + "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/models" "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/_deps" \ + "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/CMakeFiles" +cd "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}" +find -executable -type f > "${BUILD_BINARIESDIRECTORY}/${BUILD_CONFIG}/perms.txt" diff --git a/tools/ci_build/github/windows/setup_env_azure.bat b/tools/ci_build/github/windows/setup_env_azure.bat deleted file mode 100644 index 44ba34b0bf23a..0000000000000 --- a/tools/ci_build/github/windows/setup_env_azure.bat +++ /dev/null @@ -1,4 +0,0 @@ -REM Copyright (c) Microsoft Corporation. All rights reserved. -REM Licensed under the MIT License. 
-set PATH=%cd%\RelWithDebInfo\_deps\vcpkg-src\installed\x64-windows\bin;%cd%\RelWithDebInfo\_deps\vcpkg-src\installed\x86-windows\bin;%PATH%
-set GRADLE_OPTS=-Dorg.gradle.daemon=false
diff --git a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py
index 113b5398f3981..9eccb7c36455f 100644
--- a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py
+++ b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py
@@ -10,9 +10,8 @@
 import sys
 
 import onnx
-from onnx import shape_inference
 
-from ..onnx_model_utils import get_opsets_imported
+from ..onnx_model_utils import ModelProtoWithShapeInfo, get_opsets_imported
 from ..reduced_build_config_parser import parse_config
 
 cpp_to_tensorproto_type = {
@@ -265,15 +264,13 @@ def run_check(model_path: pathlib.Path, mobile_pkg_build_config: pathlib.Path, l
     )
 
     model_file = model_path.resolve(strict=True)
-    model = onnx.load(str(model_file))
 
     # we need to run shape inferencing to populate that type info for node outputs.
     # we will get warnings if the model uses ORT contrib ops (ONNX does not have shape inferencing for those),
     # and shape inferencing will be lost downstream of those.
     # TODO: add support for checking ORT format model as it will have full type/shape info for all nodes
-    model_with_type_info = shape_inference.infer_shapes(model)
-
-    return run_check_with_model(model_with_type_info, mobile_pkg_build_config, logger)
+    model_wrapper = ModelProtoWithShapeInfo(model_file)
+    return run_check_with_model(model_wrapper.model_with_shape_info, mobile_pkg_build_config, logger)
 
 
 def main():
diff --git a/tools/python/util/mobile_helpers/usability_checker.py b/tools/python/util/mobile_helpers/usability_checker.py
index f8b0bfe707ead..dcb3451a5e0fa 100644
--- a/tools/python/util/mobile_helpers/usability_checker.py
+++ b/tools/python/util/mobile_helpers/usability_checker.py
@@ -13,6 +13,7 @@
 import onnx
 
 from ..onnx_model_utils import (
+    ModelProtoWithShapeInfo,
     get_producer_consumer_maps,
     is_fixed_size_tensor,
     iterate_graph_per_graph_func,
@@ -464,9 +465,9 @@ def check_shapes(graph: onnx.GraphProto, logger: Optional[logging.Logger] = None
     return dynamic_inputs, num_dynamic_values
 
 
-def checker(model_path, logger: logging.Logger):
-    model = onnx.load(model_path)
-    model_with_shape_info = onnx.shape_inference.infer_shapes(model)
+def checker(model_path: pathlib.Path, logger: logging.Logger):
+    model_with_shape_info_wrapper = ModelProtoWithShapeInfo(model_path)
+    model_with_shape_info = model_with_shape_info_wrapper.model_with_shape_info
 
     # create lookup map for efficiency
     value_to_shape = {}
@@ -541,10 +542,10 @@ def analyze_model(model_path: pathlib.Path, skip_optimize: bool = False, logger:
     with tempfile.TemporaryDirectory() as tmp:
         if not skip_optimize:
             tmp_path = pathlib.Path(tmp) / model_path.name
-            optimize_model(model_path, tmp_path)
+            optimize_model(model_path, tmp_path, use_external_initializers=True)
             model_path = tmp_path
 
-        try_eps = checker(str(model_path.resolve(strict=True)), logger)
+        try_eps = checker(model_path.resolve(strict=True), logger)
 
     return try_eps
diff --git a/tools/python/util/onnx_model_utils.py b/tools/python/util/onnx_model_utils.py
index e662d1623f8bd..5c970430a3a82 100644
--- a/tools/python/util/onnx_model_utils.py
+++ b/tools/python/util/onnx_model_utils.py
@@ -95,6 +95,7 @@ def optimize_model(
     output_path: pathlib.Path,
     level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
     log_level: int = 3,
+    use_external_initializers: bool = False,
 ):
     """
     Optimize an ONNX model using ONNX Runtime to the specified level
@@ -103,12 +104,25 @@ def optimize_model(
     :param level: onnxruntime.GraphOptimizationLevel to use. Default is ORT_ENABLE_BASIC.
     :param log_level: Log level. Defaults to Error (3) so we don't get output about unused initializers being
                       removed. Warning (2) or Info (1) may be desirable in some scenarios.
+    :param use_external_initializers: Set flag to write initializers to an external file. Required if model > 2GB.
+                                      Requires onnxruntime 1.17+
     """
     so = ort.SessionOptions()
     so.optimized_model_filepath = str(output_path.resolve())
     so.graph_optimization_level = level
     so.log_severity_level = log_level
 
+    # save using external initializers so models > 2 GB are handled
+    if use_external_initializers:
+        major, minor, rest = ort.__version__.split(".", 3)
+        if (int(major), int(minor)) >= (1, 17):
+            so.add_session_config_entry("session.optimized_model_external_initializers_file_name", "external_data.pb")
+        else:
+            raise ValueError(
+                "ONNX Runtime 1.17 or higher required to save initializers as external data when optimizing model. "
+                f"Current ONNX Runtime version is {ort.__version__}"
+            )
+
     # create session to optimize. this will write the updated model to output_path
     _ = ort.InferenceSession(str(model_path.resolve(strict=True)), so, providers=["CPUExecutionProvider"])
 
@@ -366,3 +380,34 @@ def get_optimization_level(level):
         return ort.GraphOptimizationLevel.ORT_ENABLE_ALL
 
     raise ValueError("Invalid optimization level of " + level)
+
+
+class ModelProtoWithShapeInfo:
+    """
+    Class to load an ONNX model and run shape inferencing on it to populate the ValueInfo.
+    The model_with_shape_info property will contain the updated model.
+    If the model is > 2GB and uses external data a temporary file is required to run shape inferencing successfully.
+    This helper class handles automatic removal of the temporary file.
+    """
+
+    def __init__(self, model_path: pathlib.Path):
+        """
+        :param model_path: Path to ONNX model to load and run shape inferencing on.
+        """
+
+        self.model_path = model_path
+
+        model = onnx.load(str(model_path))
+        self.model_with_shape_info = onnx.shape_inference.infer_shapes(model, strict_mode=True)
+
+        # ONNX has a silent failure from the call to infer_shapes when the model is > 2GB.
+        # We detect that by checking the nodes in the returned model.
+        self._tmp_model_path = None
+        if len(model.graph.node) > 0 and len(self.model_with_shape_info.graph.node) == 0:
+            self._tmp_model_path = pathlib.Path(model_path).with_suffix(".temp_with_shapeinf.onnx")
+            onnx.shape_inference.infer_shapes_path(str(model_path), str(self._tmp_model_path), strict_mode=True)
+            self.model_with_shape_info = onnx.load(str(self._tmp_model_path))
+
+    def __del__(self):
+        if self._tmp_model_path:
+            self._tmp_model_path.unlink(missing_ok=True)
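Taken together, the new optimize_model flag and the ModelProtoWithShapeInfo helper give the mobile helpers a working path for models larger than 2GB: optimization writes initializers to an external file (named external_data.pb next to the optimized model), and shape inferencing falls back to infer_shapes_path through a temporary file that is cleaned up automatically. A minimal usage sketch follows, assuming onnxruntime 1.17+ is installed and that tools/python is on PYTHONPATH so the util package is importable; the model paths are illustrative only:

import pathlib

from util.onnx_model_utils import ModelProtoWithShapeInfo, optimize_model

# Illustrative paths; any ONNX model works, including ones > 2GB that use external data.
model_path = pathlib.Path("model.onnx")
optimized_path = pathlib.Path("model.optimized.onnx")

# Optimize with initializers written to an external file so the 2GB protobuf limit is not hit.
# Raises ValueError if the installed onnxruntime is older than 1.17.
optimize_model(model_path, optimized_path, use_external_initializers=True)

# Load the optimized model with ValueInfo populated by shape inferencing.
# For > 2GB models the helper writes a temporary *.temp_with_shapeinf.onnx file
# and removes it again when the wrapper is garbage collected.
wrapper = ModelProtoWithShapeInfo(optimized_path)
model = wrapper.model_with_shape_info
print(f"{len(model.graph.value_info)} value_info entries after shape inferencing")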