diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index a9a78668b4810..cdfb2139730ad 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -339,9 +339,6 @@ configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py) if (onnxruntime_ENABLE_TRAINING) - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/python/deprecated/*.py" - ) file(GLOB onnxruntime_python_root_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/*.py" ) @@ -419,10 +416,6 @@ if (onnxruntime_ENABLE_TRAINING) "${ORTTRAINING_SOURCE_DIR}/python/training/onnxblock/optim/*" ) endif() -else() - file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/python/training/*.py" - ) endif() if (onnxruntime_BUILD_UNIT_TESTS) @@ -577,9 +570,6 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py $/onnxruntime/capi/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy $ $/onnxruntime/capi/ @@ -750,9 +740,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/data/ COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/utils/hooks/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_capi_training_srcs} - $/onnxruntime/capi/training/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_root_srcs} $/onnxruntime/training/ diff --git a/js/node/package-lock.json b/js/node/package-lock.json index e8968bafc4a9f..c1cf8af4bb80e 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -22,7 +22,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": 
"^7.2.4" } }, "../common": { @@ -97,12 +97,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "node_modules/@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "node_modules/@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -528,9 +522,9 @@ "dev": true }, "node_modules/long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, "node_modules/lru-cache": { @@ -663,15 +657,6 @@ "node": "^12.13.0 || ^14.15.0 || >=16.0.0" } }, - "node_modules/onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "dependencies": { - "protobufjs": "^6.11.2" - } - }, "node_modules/onnxruntime-common": { "resolved": "../common", "link": true @@ -690,9 +675,9 @@ } }, "node_modules/protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": 
"sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "hasInstallScript": true, "dependencies": { @@ -706,13 +691,11 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" }, - "bin": { - "pbjs": "bin/pbjs", - "pbts": "bin/pbts" + "engines": { + "node": ">=12.0.0" } }, "node_modules/proxy-from-env": { @@ -789,9 +772,9 @@ ] }, "node_modules/semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "dependencies": { "lru-cache": "^6.0.0" @@ -1070,12 +1053,6 @@ "integrity": "sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==", "dev": true }, - "@types/long": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/@types/long/-/long-4.0.2.tgz", - "integrity": "sha512-MqTGEo5bj5t157U6fA/BiDynNkn0YknVdh48CMPkTSpFTVmvao5UQmm7uEF6xBEo7qIMAlY/JSleYaE6VOdpaA==", - "dev": true - }, "@types/minimist": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.2.tgz", @@ -1413,9 +1390,9 @@ "dev": true }, "long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==", + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, 
"lru-cache": { @@ -1523,15 +1500,6 @@ "set-blocking": "^2.0.0" } }, - "onnx-proto": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-8.0.1.tgz", - "integrity": "sha512-ZpPTqp5dneh2bvavk/QpDsf20JJRArjqTkiMfshGmxR8ocjmfTk80fkW00FwLO7qRtybo9NPugcWQrumHYctLQ==", - "dev": true, - "requires": { - "protobufjs": "^6.11.2" - } - }, "onnxruntime-common": { "version": "file:../common", "requires": { @@ -1549,9 +1517,9 @@ } }, "protobufjs": { - "version": "6.11.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-6.11.4.tgz", - "integrity": "sha512-5kQWPaJHi1WoCpjTGszzQ32PG2F4+wRY6BmAT4Vfw56Q2FZ4YZzK20xUYQH4YkfehY1e6QSICrJquM6xXZNcrw==", + "version": "7.2.5", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.5.tgz", + "integrity": "sha512-gGXRSXvxQ7UiPgfw8gevrfRWcTlSbOFg+p/N+JVJEK5VhueL2miT6qTymqAmjr1Q5WbOCyJbyrk6JfWKwlFn6A==", "dev": true, "requires": { "@protobufjs/aspromise": "^1.1.2", @@ -1564,9 +1532,8 @@ "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", - "@types/long": "^4.0.1", "@types/node": ">=13.7.0", - "long": "^4.0.0" + "long": "^5.0.0" } }, "proxy-from-env": { @@ -1619,9 +1586,9 @@ "dev": true }, "semver": { - "version": "7.3.8", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.3.8.tgz", - "integrity": "sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==", + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", "dev": true, "requires": { "lru-cache": "^6.0.0" diff --git a/js/node/package.json b/js/node/package.json index 0f8f0e9d2260c..8e591d8f46b9d 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -19,6 +19,7 @@ }, "scripts": { "buildr": "tsc && node ./script/build --config=RelWithDebInfo", + "preprepare": "node -e 
\"require('node:fs').copyFileSync('./node_modules/long/index.d.ts', './node_modules/long/umd/index.d.ts')\"", "prepare": "tsc --build script test .", "rebuild": "tsc && node ./script/build --rebuild", "rebuildd": "tsc && node ./script/build --rebuild --config=Debug", @@ -39,7 +40,7 @@ "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", - "onnx-proto": "^8.0.1" + "protobufjs": "^7.2.4" }, "main": "dist/index.js", "os": [ diff --git a/js/node/test/ort-schema/protobuf/.gitignore b/js/node/test/ort-schema/protobuf/.gitignore new file mode 100644 index 0000000000000..092bb6c1c9fb4 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/.gitignore @@ -0,0 +1,2 @@ +!onnx.js +!onnx.d.ts diff --git a/js/node/test/ort-schema/protobuf/README.md b/js/node/test/ort-schema/protobuf/README.md new file mode 100644 index 0000000000000..f5f52c602f1ad --- /dev/null +++ b/js/node/test/ort-schema/protobuf/README.md @@ -0,0 +1,21 @@ +# ONNX protobuf + +This directory contains generated protobuf definition for onnx: + +- onnx.js +- onnx.d.ts + +These files are generated from [a fork of onnx-proto](https://github.com/fs-eire/onnx-proto/tree/update-v9). + +The ONNX protobuf uses protobufjs@7.2.4, which depends on long@5.2.3, the version contains 2 bugs: + +- type export does not work with commonjs. described in https://github.com/dcodeIO/long.js/pull/124. added a "postinstall" script to fix. +- in the generated typescript declaration file 'onnx.d.ts', the following line: + ```ts + import Long = require("long"); + ``` + need to be replaced to fix type import error: + ```ts + import Long from "long"; + ``` + this replacement is done and code format is also applied to file 'onnx.d.ts'. 
diff --git a/js/node/test/ort-schema/protobuf/onnx.d.ts b/js/node/test/ort-schema/protobuf/onnx.d.ts new file mode 100644 index 0000000000000..c60264dca2a8d --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.d.ts @@ -0,0 +1,2627 @@ +import Long from 'long'; +import * as $protobuf from 'protobufjs'; + +/** Namespace onnx. */ +export namespace onnx { + + /** Version enum. */ + enum Version { + _START_VERSION = 0, + IR_VERSION_2017_10_10 = 1, + IR_VERSION_2017_10_30 = 2, + IR_VERSION_2017_11_3 = 3, + IR_VERSION_2019_1_22 = 4, + IR_VERSION_2019_3_18 = 5, + IR_VERSION_2019_9_19 = 6, + IR_VERSION_2020_5_8 = 7, + IR_VERSION_2021_7_30 = 8, + IR_VERSION = 9 + } + + /** Properties of an AttributeProto. */ + interface IAttributeProto { + /** AttributeProto name */ + name?: (string|null); + + /** AttributeProto refAttrName */ + refAttrName?: (string|null); + + /** AttributeProto docString */ + docString?: (string|null); + + /** AttributeProto type */ + type?: (onnx.AttributeProto.AttributeType|null); + + /** AttributeProto f */ + f?: (number|null); + + /** AttributeProto i */ + i?: (number|Long|null); + + /** AttributeProto s */ + s?: (Uint8Array|null); + + /** AttributeProto t */ + t?: (onnx.ITensorProto|null); + + /** AttributeProto g */ + g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor */ + sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp */ + tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats */ + floats?: (number[]|null); + + /** AttributeProto ints */ + ints?: ((number | Long)[]|null); + + /** AttributeProto strings */ + strings?: (Uint8Array[]|null); + + /** AttributeProto tensors */ + tensors?: (onnx.ITensorProto[]|null); + + /** AttributeProto graphs */ + graphs?: (onnx.IGraphProto[]|null); + + /** AttributeProto sparseTensors */ + sparseTensors?: (onnx.ISparseTensorProto[]|null); + + /** AttributeProto typeProtos */ + typeProtos?: (onnx.ITypeProto[]|null); + } + + /** Represents an AttributeProto. 
*/ + class AttributeProto implements IAttributeProto { + /** + * Constructs a new AttributeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IAttributeProto); + + /** AttributeProto name. */ + public name: string; + + /** AttributeProto refAttrName. */ + public refAttrName: string; + + /** AttributeProto docString. */ + public docString: string; + + /** AttributeProto type. */ + public type: onnx.AttributeProto.AttributeType; + + /** AttributeProto f. */ + public f: number; + + /** AttributeProto i. */ + public i: (number|Long); + + /** AttributeProto s. */ + public s: Uint8Array; + + /** AttributeProto t. */ + public t?: (onnx.ITensorProto|null); + + /** AttributeProto g. */ + public g?: (onnx.IGraphProto|null); + + /** AttributeProto sparseTensor. */ + public sparseTensor?: (onnx.ISparseTensorProto|null); + + /** AttributeProto tp. */ + public tp?: (onnx.ITypeProto|null); + + /** AttributeProto floats. */ + public floats: number[]; + + /** AttributeProto ints. */ + public ints: (number|Long)[]; + + /** AttributeProto strings. */ + public strings: Uint8Array[]; + + /** AttributeProto tensors. */ + public tensors: onnx.ITensorProto[]; + + /** AttributeProto graphs. */ + public graphs: onnx.IGraphProto[]; + + /** AttributeProto sparseTensors. */ + public sparseTensors: onnx.ISparseTensorProto[]; + + /** AttributeProto typeProtos. */ + public typeProtos: onnx.ITypeProto[]; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns AttributeProto instance + */ + public static create(properties?: onnx.IAttributeProto): onnx.AttributeProto; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} + * messages. 
+ * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link + * onnx.AttributeProto.verify|verify} messages. + * @param message AttributeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.AttributeProto; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.AttributeProto; + + /** + * Verifies an AttributeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. 
+ * @param object Plain object + * @returns AttributeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.AttributeProto; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @param message AttributeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.AttributeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this AttributeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for AttributeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace AttributeProto { + + /** AttributeType enum. */ + enum AttributeType { + UNDEFINED = 0, + FLOAT = 1, + INT = 2, + STRING = 3, + TENSOR = 4, + GRAPH = 5, + SPARSE_TENSOR = 11, + TYPE_PROTO = 13, + FLOATS = 6, + INTS = 7, + STRINGS = 8, + TENSORS = 9, + GRAPHS = 10, + SPARSE_TENSORS = 12, + TYPE_PROTOS = 14 + } + } + + /** Properties of a ValueInfoProto. */ + interface IValueInfoProto { + /** ValueInfoProto name */ + name?: (string|null); + + /** ValueInfoProto type */ + type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString */ + docString?: (string|null); + } + + /** Represents a ValueInfoProto. */ + class ValueInfoProto implements IValueInfoProto { + /** + * Constructs a new ValueInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IValueInfoProto); + + /** ValueInfoProto name. */ + public name: string; + + /** ValueInfoProto type. */ + public type?: (onnx.ITypeProto|null); + + /** ValueInfoProto docString. */ + public docString: string; + + /** + * Creates a new ValueInfoProto instance using the specified properties. 
+ * @param [properties] Properties to set + * @returns ValueInfoProto instance + */ + public static create(properties?: onnx.IValueInfoProto): onnx.ValueInfoProto; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} + * messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link + * onnx.ValueInfoProto.verify|verify} messages. + * @param message ValueInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ValueInfoProto; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ValueInfoProto; + + /** + * Verifies a ValueInfoProto message. 
+ * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ValueInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ValueInfoProto; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. + * @param message ValueInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ValueInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ValueInfoProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ValueInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a NodeProto. */ + interface INodeProto { + /** NodeProto input */ + input?: (string[]|null); + + /** NodeProto output */ + output?: (string[]|null); + + /** NodeProto name */ + name?: (string|null); + + /** NodeProto opType */ + opType?: (string|null); + + /** NodeProto domain */ + domain?: (string|null); + + /** NodeProto attribute */ + attribute?: (onnx.IAttributeProto[]|null); + + /** NodeProto docString */ + docString?: (string|null); + } + + /** Represents a NodeProto. */ + class NodeProto implements INodeProto { + /** + * Constructs a new NodeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.INodeProto); + + /** NodeProto input. */ + public input: string[]; + + /** NodeProto output. */ + public output: string[]; + + /** NodeProto name. 
*/ + public name: string; + + /** NodeProto opType. */ + public opType: string; + + /** NodeProto domain. */ + public domain: string; + + /** NodeProto attribute. */ + public attribute: onnx.IAttributeProto[]; + + /** NodeProto docString. */ + public docString: string; + + /** + * Creates a new NodeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns NodeProto instance + */ + public static create(properties?: onnx.INodeProto): onnx.NodeProto; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link + * onnx.NodeProto.verify|verify} messages. + * @param message NodeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.NodeProto; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.NodeProto; + + /** + * Verifies a NodeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns NodeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.NodeProto; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. + * @param message NodeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.NodeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this NodeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for NodeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TrainingInfoProto. */ + interface ITrainingInfoProto { + /** TrainingInfoProto initialization */ + initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm */ + algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding */ + initializationBinding?: (onnx.IStringStringEntryProto[]|null); + + /** TrainingInfoProto updateBinding */ + updateBinding?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TrainingInfoProto. 
*/ + class TrainingInfoProto implements ITrainingInfoProto { + /** + * Constructs a new TrainingInfoProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITrainingInfoProto); + + /** TrainingInfoProto initialization. */ + public initialization?: (onnx.IGraphProto|null); + + /** TrainingInfoProto algorithm. */ + public algorithm?: (onnx.IGraphProto|null); + + /** TrainingInfoProto initializationBinding. */ + public initializationBinding: onnx.IStringStringEntryProto[]; + + /** TrainingInfoProto updateBinding. */ + public updateBinding: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TrainingInfoProto instance + */ + public static create(properties?: onnx.ITrainingInfoProto): onnx.TrainingInfoProto; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} + * messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link + * onnx.TrainingInfoProto.verify|verify} messages. + * @param message TrainingInfoProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. 
+ * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TrainingInfoProto; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TrainingInfoProto; + + /** + * Verifies a TrainingInfoProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TrainingInfoProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TrainingInfoProto; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @param message TrainingInfoProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TrainingInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TrainingInfoProto to JSON. 
+ * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TrainingInfoProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a ModelProto. */ + interface IModelProto { + /** ModelProto irVersion */ + irVersion?: (number|Long|null); + + /** ModelProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** ModelProto producerName */ + producerName?: (string|null); + + /** ModelProto producerVersion */ + producerVersion?: (string|null); + + /** ModelProto domain */ + domain?: (string|null); + + /** ModelProto modelVersion */ + modelVersion?: (number|Long|null); + + /** ModelProto docString */ + docString?: (string|null); + + /** ModelProto graph */ + graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps */ + metadataProps?: (onnx.IStringStringEntryProto[]|null); + + /** ModelProto trainingInfo */ + trainingInfo?: (onnx.ITrainingInfoProto[]|null); + + /** ModelProto functions */ + functions?: (onnx.IFunctionProto[]|null); + } + + /** Represents a ModelProto. */ + class ModelProto implements IModelProto { + /** + * Constructs a new ModelProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IModelProto); + + /** ModelProto irVersion. */ + public irVersion: (number|Long); + + /** ModelProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** ModelProto producerName. */ + public producerName: string; + + /** ModelProto producerVersion. */ + public producerVersion: string; + + /** ModelProto domain. */ + public domain: string; + + /** ModelProto modelVersion. */ + public modelVersion: (number|Long); + + /** ModelProto docString. */ + public docString: string; + + /** ModelProto graph. */ + public graph?: (onnx.IGraphProto|null); + + /** ModelProto metadataProps. 
*/ + public metadataProps: onnx.IStringStringEntryProto[]; + + /** ModelProto trainingInfo. */ + public trainingInfo: onnx.ITrainingInfoProto[]; + + /** ModelProto functions. */ + public functions: onnx.IFunctionProto[]; + + /** + * Creates a new ModelProto instance using the specified properties. + * @param [properties] Properties to set + * @returns ModelProto instance + */ + public static create(properties?: onnx.IModelProto): onnx.ModelProto; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link + * onnx.ModelProto.verify|verify} messages. + * @param message ModelProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ModelProto; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ModelProto; + + /** + * Verifies a ModelProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns ModelProto + */ + public static fromObject(object: {[k: string]: any}): onnx.ModelProto; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @param message ModelProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.ModelProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this ModelProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for ModelProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a StringStringEntryProto. */ + interface IStringStringEntryProto { + /** StringStringEntryProto key */ + key?: (string|null); + + /** StringStringEntryProto value */ + value?: (string|null); + } + + /** Represents a StringStringEntryProto. */ + class StringStringEntryProto implements IStringStringEntryProto { + /** + * Constructs a new StringStringEntryProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IStringStringEntryProto); + + /** StringStringEntryProto key. 
*/ + public key: string; + + /** StringStringEntryProto value. */ + public value: string; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. + * @param [properties] Properties to set + * @returns StringStringEntryProto instance + */ + public static create(properties?: onnx.IStringStringEntryProto): onnx.StringStringEntryProto; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link + * onnx.StringStringEntryProto.verify|verify} messages. + * @param message StringStringEntryProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.StringStringEntryProto; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.StringStringEntryProto; + + /** + * Verifies a StringStringEntryProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns StringStringEntryProto + */ + public static fromObject(object: {[k: string]: any}): onnx.StringStringEntryProto; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @param message StringStringEntryProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.StringStringEntryProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this StringStringEntryProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for StringStringEntryProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorAnnotation. */ + interface ITensorAnnotation { + /** TensorAnnotation tensorName */ + tensorName?: (string|null); + + /** TensorAnnotation quantParameterTensorNames */ + quantParameterTensorNames?: (onnx.IStringStringEntryProto[]|null); + } + + /** Represents a TensorAnnotation. 
*/ + class TensorAnnotation implements ITensorAnnotation { + /** + * Constructs a new TensorAnnotation. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorAnnotation); + + /** TensorAnnotation tensorName. */ + public tensorName: string; + + /** TensorAnnotation quantParameterTensorNames. */ + public quantParameterTensorNames: onnx.IStringStringEntryProto[]; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorAnnotation instance + */ + public static create(properties?: onnx.ITensorAnnotation): onnx.TensorAnnotation; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} + * messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link + * onnx.TensorAnnotation.verify|verify} messages. + * @param message TensorAnnotation message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorAnnotation; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorAnnotation; + + /** + * Verifies a TensorAnnotation message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorAnnotation + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorAnnotation; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. + * @param message TensorAnnotation + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorAnnotation, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorAnnotation to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorAnnotation + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a GraphProto. 
*/ + interface IGraphProto { + /** GraphProto node */ + node?: (onnx.INodeProto[]|null); + + /** GraphProto name */ + name?: (string|null); + + /** GraphProto initializer */ + initializer?: (onnx.ITensorProto[]|null); + + /** GraphProto sparseInitializer */ + sparseInitializer?: (onnx.ISparseTensorProto[]|null); + + /** GraphProto docString */ + docString?: (string|null); + + /** GraphProto input */ + input?: (onnx.IValueInfoProto[]|null); + + /** GraphProto output */ + output?: (onnx.IValueInfoProto[]|null); + + /** GraphProto valueInfo */ + valueInfo?: (onnx.IValueInfoProto[]|null); + + /** GraphProto quantizationAnnotation */ + quantizationAnnotation?: (onnx.ITensorAnnotation[]|null); + } + + /** Represents a GraphProto. */ + class GraphProto implements IGraphProto { + /** + * Constructs a new GraphProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IGraphProto); + + /** GraphProto node. */ + public node: onnx.INodeProto[]; + + /** GraphProto name. */ + public name: string; + + /** GraphProto initializer. */ + public initializer: onnx.ITensorProto[]; + + /** GraphProto sparseInitializer. */ + public sparseInitializer: onnx.ISparseTensorProto[]; + + /** GraphProto docString. */ + public docString: string; + + /** GraphProto input. */ + public input: onnx.IValueInfoProto[]; + + /** GraphProto output. */ + public output: onnx.IValueInfoProto[]; + + /** GraphProto valueInfo. */ + public valueInfo: onnx.IValueInfoProto[]; + + /** GraphProto quantizationAnnotation. */ + public quantizationAnnotation: onnx.ITensorAnnotation[]; + + /** + * Creates a new GraphProto instance using the specified properties. + * @param [properties] Properties to set + * @returns GraphProto instance + */ + public static create(properties?: onnx.IGraphProto): onnx.GraphProto; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. 
+ * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link + * onnx.GraphProto.verify|verify} messages. + * @param message GraphProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a GraphProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.GraphProto; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.GraphProto; + + /** + * Verifies a GraphProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns GraphProto + */ + public static fromObject(object: {[k: string]: any}): onnx.GraphProto; + + /** + * Creates a plain object from a GraphProto message. 
Also converts values to other types if specified. + * @param message GraphProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.GraphProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this GraphProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for GraphProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorProto. */ + interface ITensorProto { + /** TensorProto dims */ + dims?: ((number | Long)[]|null); + + /** TensorProto dataType */ + dataType?: (number|null); + + /** TensorProto segment */ + segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData */ + floatData?: (number[]|null); + + /** TensorProto int32Data */ + int32Data?: (number[]|null); + + /** TensorProto stringData */ + stringData?: (Uint8Array[]|null); + + /** TensorProto int64Data */ + int64Data?: ((number | Long)[]|null); + + /** TensorProto name */ + name?: (string|null); + + /** TensorProto docString */ + docString?: (string|null); + + /** TensorProto rawData */ + rawData?: (Uint8Array|null); + + /** TensorProto externalData */ + externalData?: (onnx.IStringStringEntryProto[]|null); + + /** TensorProto dataLocation */ + dataLocation?: (onnx.TensorProto.DataLocation|null); + + /** TensorProto doubleData */ + doubleData?: (number[]|null); + + /** TensorProto uint64Data */ + uint64Data?: ((number | Long)[]|null); + } + + /** Represents a TensorProto. */ + class TensorProto implements ITensorProto { + /** + * Constructs a new TensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorProto); + + /** TensorProto dims. */ + public dims: (number|Long)[]; + + /** TensorProto dataType. 
*/ + public dataType: number; + + /** TensorProto segment. */ + public segment?: (onnx.TensorProto.ISegment|null); + + /** TensorProto floatData. */ + public floatData: number[]; + + /** TensorProto int32Data. */ + public int32Data: number[]; + + /** TensorProto stringData. */ + public stringData: Uint8Array[]; + + /** TensorProto int64Data. */ + public int64Data: (number|Long)[]; + + /** TensorProto name. */ + public name: string; + + /** TensorProto docString. */ + public docString: string; + + /** TensorProto rawData. */ + public rawData: Uint8Array; + + /** TensorProto externalData. */ + public externalData: onnx.IStringStringEntryProto[]; + + /** TensorProto dataLocation. */ + public dataLocation: onnx.TensorProto.DataLocation; + + /** TensorProto doubleData. */ + public doubleData: number[]; + + /** TensorProto uint64Data. */ + public uint64Data: (number|Long)[]; + + /** + * Creates a new TensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorProto instance + */ + public static create(properties?: onnx.ITensorProto): onnx.TensorProto; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link + * onnx.TensorProto.verify|verify} messages. + * @param message TensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorProto message from the specified reader or buffer. 
+ * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto; + + /** + * Verifies a TensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. + * @param message TensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorProto { + + /** DataType enum. 
*/ + enum DataType { + UNDEFINED = 0, + FLOAT = 1, + UINT8 = 2, + INT8 = 3, + UINT16 = 4, + INT16 = 5, + INT32 = 6, + INT64 = 7, + STRING = 8, + BOOL = 9, + FLOAT16 = 10, + DOUBLE = 11, + UINT32 = 12, + UINT64 = 13, + COMPLEX64 = 14, + COMPLEX128 = 15, + BFLOAT16 = 16, + FLOAT8E4M3FN = 17, + FLOAT8E4M3FNUZ = 18, + FLOAT8E5M2 = 19, + FLOAT8E5M2FNUZ = 20 + } + + /** Properties of a Segment. */ + interface ISegment { + /** Segment begin */ + begin?: (number|Long|null); + + /** Segment end */ + end?: (number|Long|null); + } + + /** Represents a Segment. */ + class Segment implements ISegment { + /** + * Constructs a new Segment. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorProto.ISegment); + + /** Segment begin. */ + public begin: (number|Long); + + /** Segment end. */ + public end: (number|Long); + + /** + * Creates a new Segment instance using the specified properties. + * @param [properties] Properties to set + * @returns Segment instance + */ + public static create(properties?: onnx.TensorProto.ISegment): onnx.TensorProto.Segment; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} + * messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link + * onnx.TensorProto.Segment.verify|verify} messages. + * @param message Segment message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Segment message from the specified reader or buffer. 
+ * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto.Segment; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto.Segment; + + /** + * Verifies a Segment message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Segment + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorProto.Segment; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @param message Segment + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorProto.Segment, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Segment to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Segment + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** DataLocation enum. 
*/ + enum DataLocation { DEFAULT = 0, EXTERNAL = 1 } + } + + /** Properties of a SparseTensorProto. */ + interface ISparseTensorProto { + /** SparseTensorProto values */ + values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices */ + indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims */ + dims?: ((number | Long)[]|null); + } + + /** Represents a SparseTensorProto. */ + class SparseTensorProto implements ISparseTensorProto { + /** + * Constructs a new SparseTensorProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ISparseTensorProto); + + /** SparseTensorProto values. */ + public values?: (onnx.ITensorProto|null); + + /** SparseTensorProto indices. */ + public indices?: (onnx.ITensorProto|null); + + /** SparseTensorProto dims. */ + public dims: (number|Long)[]; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensorProto instance + */ + public static create(properties?: onnx.ISparseTensorProto): onnx.SparseTensorProto; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} + * messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link + * onnx.SparseTensorProto.verify|verify} messages. + * @param message SparseTensorProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. 
+ * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.SparseTensorProto; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.SparseTensorProto; + + /** + * Verifies a SparseTensorProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns SparseTensorProto + */ + public static fromObject(object: {[k: string]: any}): onnx.SparseTensorProto; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. + * @param message SparseTensorProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.SparseTensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this SparseTensorProto to JSON. 
+ * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensorProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a TensorShapeProto. */ + interface ITensorShapeProto { + /** TensorShapeProto dim */ + dim?: (onnx.TensorShapeProto.IDimension[]|null); + } + + /** Represents a TensorShapeProto. */ + class TensorShapeProto implements ITensorShapeProto { + /** + * Constructs a new TensorShapeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITensorShapeProto); + + /** TensorShapeProto dim. */ + public dim: onnx.TensorShapeProto.IDimension[]; + + /** + * Creates a new TensorShapeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TensorShapeProto instance + */ + public static create(properties?: onnx.ITensorShapeProto): onnx.TensorShapeProto; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} + * messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.verify|verify} messages. + * @param message TensorShapeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. 
+ * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto; + + /** + * Verifies a TensorShapeProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TensorShapeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. + * @param message TensorShapeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TensorShapeProto to JSON. 
+ * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TensorShapeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TensorShapeProto { + + /** Properties of a Dimension. */ + interface IDimension { + /** Dimension dimValue */ + dimValue?: (number|Long|null); + + /** Dimension dimParam */ + dimParam?: (string|null); + + /** Dimension denotation */ + denotation?: (string|null); + } + + /** Represents a Dimension. */ + class Dimension implements IDimension { + /** + * Constructs a new Dimension. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TensorShapeProto.IDimension); + + /** Dimension dimValue. */ + public dimValue?: (number|Long|null); + + /** Dimension dimParam. */ + public dimParam?: (string|null); + + /** Dimension denotation. */ + public denotation: string; + + /** Dimension value. */ + public value?: ('dimValue'|'dimParam'); + + /** + * Creates a new Dimension instance using the specified properties. + * @param [properties] Properties to set + * @returns Dimension instance + */ + public static create(properties?: onnx.TensorShapeProto.IDimension): onnx.TensorShapeProto.Dimension; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link + * onnx.TensorShapeProto.Dimension.verify|verify} messages. 
+ * @param message Dimension message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): + $protobuf.Writer; + + /** + * Decodes a Dimension message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto.Dimension; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto.Dimension; + + /** + * Verifies a Dimension message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Dimension + */ + public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto.Dimension; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. 
+ * @param message Dimension + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TensorShapeProto.Dimension, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Dimension to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Dimension + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of a TypeProto. */ + interface ITypeProto { + /** TypeProto tensorType */ + tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType */ + sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType */ + mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType */ + optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType */ + sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation */ + denotation?: (string|null); + } + + /** Represents a TypeProto. */ + class TypeProto implements ITypeProto { + /** + * Constructs a new TypeProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.ITypeProto); + + /** TypeProto tensorType. */ + public tensorType?: (onnx.TypeProto.ITensor|null); + + /** TypeProto sequenceType. */ + public sequenceType?: (onnx.TypeProto.ISequence|null); + + /** TypeProto mapType. */ + public mapType?: (onnx.TypeProto.IMap|null); + + /** TypeProto optionalType. */ + public optionalType?: (onnx.TypeProto.IOptional|null); + + /** TypeProto sparseTensorType. */ + public sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + + /** TypeProto denotation. */ + public denotation: string; + + /** TypeProto value. 
*/ + public value?: ('tensorType'|'sequenceType'|'mapType'|'optionalType'|'sparseTensorType'); + + /** + * Creates a new TypeProto instance using the specified properties. + * @param [properties] Properties to set + * @returns TypeProto instance + */ + public static create(properties?: onnx.ITypeProto): onnx.TypeProto; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link + * onnx.TypeProto.verify|verify} messages. + * @param message TypeProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto; + + /** + * Verifies a TypeProto message. 
+ * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns TypeProto + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto; + + /** + * Creates a plain object from a TypeProto message. Also converts values to other types if specified. + * @param message TypeProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this TypeProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for TypeProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + namespace TypeProto { + + /** Properties of a Tensor. */ + interface ITensor { + /** Tensor elemType */ + elemType?: (number|null); + + /** Tensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a Tensor. */ + class Tensor implements ITensor { + /** + * Constructs a new Tensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ITensor); + + /** Tensor elemType. */ + public elemType: number; + + /** Tensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new Tensor instance using the specified properties. + * @param [properties] Properties to set + * @returns Tensor instance + */ + public static create(properties?: onnx.TypeProto.ITensor): onnx.TypeProto.Tensor; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. 
+ * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Tensor.verify|verify} messages. + * @param message Tensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Tensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Tensor; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Tensor; + + /** + * Verifies a Tensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Tensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Tensor; + + /** + * Creates a plain object from a Tensor message. 
Also converts values to other types if specified. + * @param message Tensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Tensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Tensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Tensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Sequence. */ + interface ISequence { + /** Sequence elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents a Sequence. */ + class Sequence implements ISequence { + /** + * Constructs a new Sequence. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISequence); + + /** Sequence elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Sequence instance using the specified properties. + * @param [properties] Properties to set + * @returns Sequence instance + */ + public static create(properties?: onnx.TypeProto.ISequence): onnx.TypeProto.Sequence; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} + * messages. + * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Sequence.verify|verify} messages. 
+ * @param message Sequence message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Sequence; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Sequence; + + /** + * Verifies a Sequence message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Sequence + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Sequence; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. + * @param message Sequence + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Sequence, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Sequence to JSON. 
+ * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Sequence + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a Map. */ + interface IMap { + /** Map keyType */ + keyType?: (number|null); + + /** Map valueType */ + valueType?: (onnx.ITypeProto|null); + } + + /** Represents a Map. */ + class Map implements IMap { + /** + * Constructs a new Map. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IMap); + + /** Map keyType. */ + public keyType: number; + + /** Map valueType. */ + public valueType?: (onnx.ITypeProto|null); + + /** + * Creates a new Map instance using the specified properties. + * @param [properties] Properties to set + * @returns Map instance + */ + public static create(properties?: onnx.TypeProto.IMap): onnx.TypeProto.Map; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Map.verify|verify} messages. + * @param message Map message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a Map message from the specified reader or buffer. 
+ * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Map; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Map; + + /** + * Verifies a Map message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Map + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Map; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @param message Map + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Map, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this Map to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Map + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of an Optional. 
*/ + interface IOptional { + /** Optional elemType */ + elemType?: (onnx.ITypeProto|null); + } + + /** Represents an Optional. */ + class Optional implements IOptional { + /** + * Constructs a new Optional. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.IOptional); + + /** Optional elemType. */ + public elemType?: (onnx.ITypeProto|null); + + /** + * Creates a new Optional instance using the specified properties. + * @param [properties] Properties to set + * @returns Optional instance + */ + public static create(properties?: onnx.TypeProto.IOptional): onnx.TypeProto.Optional; + + /** + * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} + * messages. + * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link + * onnx.TypeProto.Optional.verify|verify} messages. + * @param message Optional message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Optional; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Optional; + + /** + * Verifies an Optional message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns Optional + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Optional; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. + * @param message Optional + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.Optional, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this Optional to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for Optional + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** Properties of a SparseTensor. */ + interface ISparseTensor { + /** SparseTensor elemType */ + elemType?: (number|null); + + /** SparseTensor shape */ + shape?: (onnx.ITensorShapeProto|null); + } + + /** Represents a SparseTensor. */ + class SparseTensor implements ISparseTensor { + /** + * Constructs a new SparseTensor. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.TypeProto.ISparseTensor); + + /** SparseTensor elemType. 
*/ + public elemType: number; + + /** SparseTensor shape. */ + public shape?: (onnx.ITensorShapeProto|null); + + /** + * Creates a new SparseTensor instance using the specified properties. + * @param [properties] Properties to set + * @returns SparseTensor instance + */ + public static create(properties?: onnx.TypeProto.ISparseTensor): onnx.TypeProto.SparseTensor; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link + * onnx.TypeProto.SparseTensor.verify|verify} messages. + * @param message SparseTensor message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.SparseTensor; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.SparseTensor; + + /** + * Verifies a SparseTensor message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @param object Plain object + * @returns SparseTensor + */ + public static fromObject(object: {[k: string]: any}): onnx.TypeProto.SparseTensor; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. + * @param message SparseTensor + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.TypeProto.SparseTensor, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this SparseTensor to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for SparseTensor + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + } + + /** Properties of an OperatorSetIdProto. */ + interface IOperatorSetIdProto { + /** OperatorSetIdProto domain */ + domain?: (string|null); + + /** OperatorSetIdProto version */ + version?: (number|Long|null); + } + + /** Represents an OperatorSetIdProto. */ + class OperatorSetIdProto implements IOperatorSetIdProto { + /** + * Constructs a new OperatorSetIdProto. 
+ * @param [properties] Properties to set + */ + constructor(properties?: onnx.IOperatorSetIdProto); + + /** OperatorSetIdProto domain. */ + public domain: string; + + /** OperatorSetIdProto version. */ + public version: (number|Long); + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @param [properties] Properties to set + * @returns OperatorSetIdProto instance + */ + public static create(properties?: onnx.IOperatorSetIdProto): onnx.OperatorSetIdProto; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. + * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link + * onnx.OperatorSetIdProto.verify|verify} messages. + * @param message OperatorSetIdProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.OperatorSetIdProto; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. 
+ * @param reader Reader or buffer to decode from + * @returns OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.OperatorSetIdProto; + + /** + * Verifies an OperatorSetIdProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal + * types. + * @param object Plain object + * @returns OperatorSetIdProto + */ + public static fromObject(object: {[k: string]: any}): onnx.OperatorSetIdProto; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @param message OperatorSetIdProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.OperatorSetIdProto, options?: $protobuf.IConversionOptions): + {[k: string]: any}; + + /** + * Converts this OperatorSetIdProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for OperatorSetIdProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } + + /** OperatorStatus enum. */ + enum OperatorStatus { EXPERIMENTAL = 0, STABLE = 1 } + + /** Properties of a FunctionProto. 
*/ + interface IFunctionProto { + /** FunctionProto name */ + name?: (string|null); + + /** FunctionProto input */ + input?: (string[]|null); + + /** FunctionProto output */ + output?: (string[]|null); + + /** FunctionProto attribute */ + attribute?: (string[]|null); + + /** FunctionProto attributeProto */ + attributeProto?: (onnx.IAttributeProto[]|null); + + /** FunctionProto node */ + node?: (onnx.INodeProto[]|null); + + /** FunctionProto docString */ + docString?: (string|null); + + /** FunctionProto opsetImport */ + opsetImport?: (onnx.IOperatorSetIdProto[]|null); + + /** FunctionProto domain */ + domain?: (string|null); + } + + /** Represents a FunctionProto. */ + class FunctionProto implements IFunctionProto { + /** + * Constructs a new FunctionProto. + * @param [properties] Properties to set + */ + constructor(properties?: onnx.IFunctionProto); + + /** FunctionProto name. */ + public name: string; + + /** FunctionProto input. */ + public input: string[]; + + /** FunctionProto output. */ + public output: string[]; + + /** FunctionProto attribute. */ + public attribute: string[]; + + /** FunctionProto attributeProto. */ + public attributeProto: onnx.IAttributeProto[]; + + /** FunctionProto node. */ + public node: onnx.INodeProto[]; + + /** FunctionProto docString. */ + public docString: string; + + /** FunctionProto opsetImport. */ + public opsetImport: onnx.IOperatorSetIdProto[]; + + /** FunctionProto domain. */ + public domain: string; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @param [properties] Properties to set + * @returns FunctionProto instance + */ + public static create(properties?: onnx.IFunctionProto): onnx.FunctionProto; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} + * messages. 
+ * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encode(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link + * onnx.FunctionProto.verify|verify} messages. + * @param message FunctionProto message or plain object to encode + * @param [writer] Writer to encode to + * @returns Writer + */ + public static encodeDelimited(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @param reader Reader or buffer to decode from + * @param [length] Message length if known beforehand + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.FunctionProto; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. + * @param reader Reader or buffer to decode from + * @returns FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.FunctionProto; + + /** + * Verifies a FunctionProto message. + * @param message Plain object to verify + * @returns `null` if valid, otherwise the reason why it is not + */ + public static verify(message: {[k: string]: any}): (string|null); + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. 
+ * @param object Plain object + * @returns FunctionProto + */ + public static fromObject(object: {[k: string]: any}): onnx.FunctionProto; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. + * @param message FunctionProto + * @param [options] Conversion options + * @returns Plain object + */ + public static toObject(message: onnx.FunctionProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + + /** + * Converts this FunctionProto to JSON. + * @returns JSON object + */ + public toJSON(): {[k: string]: any}; + + /** + * Gets the default type url for FunctionProto + * @param [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns The default type url + */ + public static getTypeUrl(typeUrlPrefix?: string): string; + } +} diff --git a/js/node/test/ort-schema/protobuf/onnx.js b/js/node/test/ort-schema/protobuf/onnx.js new file mode 100644 index 0000000000000..681855132d4e8 --- /dev/null +++ b/js/node/test/ort-schema/protobuf/onnx.js @@ -0,0 +1,7658 @@ +/*eslint-disable block-scoped-var, id-length, no-control-regex, no-magic-numbers, no-prototype-builtins, no-redeclare, no-shadow, no-var, sort-vars*/ +"use strict"; + +var $protobuf = require("protobufjs/minimal"); + +// Common aliases +var $Reader = $protobuf.Reader, $Writer = $protobuf.Writer, $util = $protobuf.util; + +// Exported root namespace +var $root = $protobuf.roots["default"] || ($protobuf.roots["default"] = {}); + +$root.onnx = (function() { + + /** + * Namespace onnx. + * @exports onnx + * @namespace + */ + var onnx = {}; + + /** + * Version enum. 
+ * @name onnx.Version + * @enum {number} + * @property {number} _START_VERSION=0 _START_VERSION value + * @property {number} IR_VERSION_2017_10_10=1 IR_VERSION_2017_10_10 value + * @property {number} IR_VERSION_2017_10_30=2 IR_VERSION_2017_10_30 value + * @property {number} IR_VERSION_2017_11_3=3 IR_VERSION_2017_11_3 value + * @property {number} IR_VERSION_2019_1_22=4 IR_VERSION_2019_1_22 value + * @property {number} IR_VERSION_2019_3_18=5 IR_VERSION_2019_3_18 value + * @property {number} IR_VERSION_2019_9_19=6 IR_VERSION_2019_9_19 value + * @property {number} IR_VERSION_2020_5_8=7 IR_VERSION_2020_5_8 value + * @property {number} IR_VERSION_2021_7_30=8 IR_VERSION_2021_7_30 value + * @property {number} IR_VERSION=9 IR_VERSION value + */ + onnx.Version = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "_START_VERSION"] = 0; + values[valuesById[1] = "IR_VERSION_2017_10_10"] = 1; + values[valuesById[2] = "IR_VERSION_2017_10_30"] = 2; + values[valuesById[3] = "IR_VERSION_2017_11_3"] = 3; + values[valuesById[4] = "IR_VERSION_2019_1_22"] = 4; + values[valuesById[5] = "IR_VERSION_2019_3_18"] = 5; + values[valuesById[6] = "IR_VERSION_2019_9_19"] = 6; + values[valuesById[7] = "IR_VERSION_2020_5_8"] = 7; + values[valuesById[8] = "IR_VERSION_2021_7_30"] = 8; + values[valuesById[9] = "IR_VERSION"] = 9; + return values; + })(); + + onnx.AttributeProto = (function() { + + /** + * Properties of an AttributeProto. 
+ * @memberof onnx + * @interface IAttributeProto + * @property {string|null} [name] AttributeProto name + * @property {string|null} [refAttrName] AttributeProto refAttrName + * @property {string|null} [docString] AttributeProto docString + * @property {onnx.AttributeProto.AttributeType|null} [type] AttributeProto type + * @property {number|null} [f] AttributeProto f + * @property {number|Long|null} [i] AttributeProto i + * @property {Uint8Array|null} [s] AttributeProto s + * @property {onnx.ITensorProto|null} [t] AttributeProto t + * @property {onnx.IGraphProto|null} [g] AttributeProto g + * @property {onnx.ISparseTensorProto|null} [sparseTensor] AttributeProto sparseTensor + * @property {onnx.ITypeProto|null} [tp] AttributeProto tp + * @property {Array.|null} [floats] AttributeProto floats + * @property {Array.|null} [ints] AttributeProto ints + * @property {Array.|null} [strings] AttributeProto strings + * @property {Array.|null} [tensors] AttributeProto tensors + * @property {Array.|null} [graphs] AttributeProto graphs + * @property {Array.|null} [sparseTensors] AttributeProto sparseTensors + * @property {Array.|null} [typeProtos] AttributeProto typeProtos + */ + + /** + * Constructs a new AttributeProto. + * @memberof onnx + * @classdesc Represents an AttributeProto. + * @implements IAttributeProto + * @constructor + * @param {onnx.IAttributeProto=} [properties] Properties to set + */ + function AttributeProto(properties) { + this.floats = []; + this.ints = []; + this.strings = []; + this.tensors = []; + this.graphs = []; + this.sparseTensors = []; + this.typeProtos = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * AttributeProto name. + * @member {string} name + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.name = ""; + + /** + * AttributeProto refAttrName. 
+ * @member {string} refAttrName + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.refAttrName = ""; + + /** + * AttributeProto docString. + * @member {string} docString + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.docString = ""; + + /** + * AttributeProto type. + * @member {onnx.AttributeProto.AttributeType} type + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.type = 0; + + /** + * AttributeProto f. + * @member {number} f + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.f = 0; + + /** + * AttributeProto i. + * @member {number|Long} i + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * AttributeProto s. + * @member {Uint8Array} s + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.s = $util.newBuffer([]); + + /** + * AttributeProto t. + * @member {onnx.ITensorProto|null|undefined} t + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.t = null; + + /** + * AttributeProto g. + * @member {onnx.IGraphProto|null|undefined} g + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.g = null; + + /** + * AttributeProto sparseTensor. + * @member {onnx.ISparseTensorProto|null|undefined} sparseTensor + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensor = null; + + /** + * AttributeProto tp. + * @member {onnx.ITypeProto|null|undefined} tp + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tp = null; + + /** + * AttributeProto floats. + * @member {Array.} floats + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.floats = $util.emptyArray; + + /** + * AttributeProto ints. 
+ * @member {Array.} ints + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.ints = $util.emptyArray; + + /** + * AttributeProto strings. + * @member {Array.} strings + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.strings = $util.emptyArray; + + /** + * AttributeProto tensors. + * @member {Array.} tensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.tensors = $util.emptyArray; + + /** + * AttributeProto graphs. + * @member {Array.} graphs + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.graphs = $util.emptyArray; + + /** + * AttributeProto sparseTensors. + * @member {Array.} sparseTensors + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.sparseTensors = $util.emptyArray; + + /** + * AttributeProto typeProtos. + * @member {Array.} typeProtos + * @memberof onnx.AttributeProto + * @instance + */ + AttributeProto.prototype.typeProtos = $util.emptyArray; + + /** + * Creates a new AttributeProto instance using the specified properties. + * @function create + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto=} [properties] Properties to set + * @returns {onnx.AttributeProto} AttributeProto instance + */ + AttributeProto.create = function create(properties) { + return new AttributeProto(properties); + }; + + /** + * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.f != null && Object.hasOwnProperty.call(message, "f")) + writer.uint32(/* id 2, wireType 5 =*/21).float(message.f); + if (message.i != null && Object.hasOwnProperty.call(message, "i")) + writer.uint32(/* id 3, wireType 0 =*/24).int64(message.i); + if (message.s != null && Object.hasOwnProperty.call(message, "s")) + writer.uint32(/* id 4, wireType 2 =*/34).bytes(message.s); + if (message.t != null && Object.hasOwnProperty.call(message, "t")) + $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.g != null && Object.hasOwnProperty.call(message, "g")) + $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/50).fork()).ldelim(); + if (message.floats != null && message.floats.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.floats.length; ++i) + writer.float(message.floats[i]); + writer.ldelim(); + } + if (message.ints != null && message.ints.length) { + writer.uint32(/* id 8, wireType 2 =*/66).fork(); + for (var i = 0; i < message.ints.length; ++i) + writer.int64(message.ints[i]); + writer.ldelim(); + } + if (message.strings != null && message.strings.length) + for (var i = 0; i < message.strings.length; ++i) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.strings[i]); + if (message.tensors != null && message.tensors.length) + for (var i = 0; i < message.tensors.length; ++i) + $root.onnx.TensorProto.encode(message.tensors[i], 
writer.uint32(/* id 10, wireType 2 =*/82).fork()).ldelim(); + if (message.graphs != null && message.graphs.length) + for (var i = 0; i < message.graphs.length; ++i) + $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 13, wireType 2 =*/106).string(message.docString); + if (message.tp != null && Object.hasOwnProperty.call(message, "tp")) + $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.typeProtos != null && message.typeProtos.length) + for (var i = 0; i < message.typeProtos.length; ++i) + $root.onnx.TypeProto.encode(message.typeProtos[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + writer.uint32(/* id 20, wireType 0 =*/160).int32(message.type); + if (message.refAttrName != null && Object.hasOwnProperty.call(message, "refAttrName")) + writer.uint32(/* id 21, wireType 2 =*/170).string(message.refAttrName); + if (message.sparseTensor != null && Object.hasOwnProperty.call(message, "sparseTensor")) + $root.onnx.SparseTensorProto.encode(message.sparseTensor, writer.uint32(/* id 22, wireType 2 =*/178).fork()).ldelim(); + if (message.sparseTensors != null && message.sparseTensors.length) + for (var i = 0; i < message.sparseTensors.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseTensors[i], writer.uint32(/* id 23, wireType 2 =*/186).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link onnx.AttributeProto.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {onnx.IAttributeProto} message AttributeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + AttributeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.AttributeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 21: { + message.refAttrName = reader.string(); + break; + } + case 13: { + message.docString = reader.string(); + break; + } + case 20: { + message.type = reader.int32(); + break; + } + case 2: { + message.f = reader.float(); + break; + } + case 3: { + message.i = reader.int64(); + break; + } + case 4: { + message.s = reader.bytes(); + break; + } + case 5: { + message.t = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 6: { + message.g = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 22: { + message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); + break; + } + case 14: { + message.tp = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 7: { + if (!(message.floats && message.floats.length)) + message.floats = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floats.push(reader.float()); + } else + message.floats.push(reader.float()); + break; + } + case 8: { + if (!(message.ints && message.ints.length)) + message.ints = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.ints.push(reader.int64()); + } else + message.ints.push(reader.int64()); + break; + } + case 9: { + if (!(message.strings && message.strings.length)) + message.strings = []; + message.strings.push(reader.bytes()); + break; + } + case 10: { + if (!(message.tensors && message.tensors.length)) + message.tensors = []; + message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 11: { + if (!(message.graphs && message.graphs.length)) + message.graphs = []; + message.graphs.push($root.onnx.GraphProto.decode(reader, 
reader.uint32())); + break; + } + case 23: { + if (!(message.sparseTensors && message.sparseTensors.length)) + message.sparseTensors = []; + message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.typeProtos && message.typeProtos.length)) + message.typeProtos = []; + message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an AttributeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.AttributeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.AttributeProto} AttributeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + AttributeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an AttributeProto message. 
+ * @function verify + * @memberof onnx.AttributeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + AttributeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + if (!$util.isString(message.refAttrName)) + return "refAttrName: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.type != null && message.hasOwnProperty("type")) + switch (message.type) { + default: + return "type: enum value expected"; + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 11: + case 13: + case 6: + case 7: + case 8: + case 9: + case 10: + case 12: + case 14: + break; + } + if (message.f != null && message.hasOwnProperty("f")) + if (typeof message.f !== "number") + return "f: number expected"; + if (message.i != null && message.hasOwnProperty("i")) + if (!$util.isInteger(message.i) && !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high))) + return "i: integer|Long expected"; + if (message.s != null && message.hasOwnProperty("s")) + if (!(message.s && typeof message.s.length === "number" || $util.isString(message.s))) + return "s: buffer expected"; + if (message.t != null && message.hasOwnProperty("t")) { + var error = $root.onnx.TensorProto.verify(message.t); + if (error) + return "t." + error; + } + if (message.g != null && message.hasOwnProperty("g")) { + var error = $root.onnx.GraphProto.verify(message.g); + if (error) + return "g." 
+ error; + } + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor); + if (error) + return "sparseTensor." + error; + } + if (message.tp != null && message.hasOwnProperty("tp")) { + var error = $root.onnx.TypeProto.verify(message.tp); + if (error) + return "tp." + error; + } + if (message.floats != null && message.hasOwnProperty("floats")) { + if (!Array.isArray(message.floats)) + return "floats: array expected"; + for (var i = 0; i < message.floats.length; ++i) + if (typeof message.floats[i] !== "number") + return "floats: number[] expected"; + } + if (message.ints != null && message.hasOwnProperty("ints")) { + if (!Array.isArray(message.ints)) + return "ints: array expected"; + for (var i = 0; i < message.ints.length; ++i) + if (!$util.isInteger(message.ints[i]) && !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high))) + return "ints: integer|Long[] expected"; + } + if (message.strings != null && message.hasOwnProperty("strings")) { + if (!Array.isArray(message.strings)) + return "strings: array expected"; + for (var i = 0; i < message.strings.length; ++i) + if (!(message.strings[i] && typeof message.strings[i].length === "number" || $util.isString(message.strings[i]))) + return "strings: buffer[] expected"; + } + if (message.tensors != null && message.hasOwnProperty("tensors")) { + if (!Array.isArray(message.tensors)) + return "tensors: array expected"; + for (var i = 0; i < message.tensors.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.tensors[i]); + if (error) + return "tensors." + error; + } + } + if (message.graphs != null && message.hasOwnProperty("graphs")) { + if (!Array.isArray(message.graphs)) + return "graphs: array expected"; + for (var i = 0; i < message.graphs.length; ++i) { + var error = $root.onnx.GraphProto.verify(message.graphs[i]); + if (error) + return "graphs." 
+ error; + } + } + if (message.sparseTensors != null && message.hasOwnProperty("sparseTensors")) { + if (!Array.isArray(message.sparseTensors)) + return "sparseTensors: array expected"; + for (var i = 0; i < message.sparseTensors.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); + if (error) + return "sparseTensors." + error; + } + } + if (message.typeProtos != null && message.hasOwnProperty("typeProtos")) { + if (!Array.isArray(message.typeProtos)) + return "typeProtos: array expected"; + for (var i = 0; i < message.typeProtos.length; ++i) { + var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); + if (error) + return "typeProtos." + error; + } + } + return null; + }; + + /** + * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.AttributeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.AttributeProto} AttributeProto + */ + AttributeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.AttributeProto) + return object; + var message = new $root.onnx.AttributeProto(); + if (object.name != null) + message.name = String(object.name); + if (object.refAttrName != null) + message.refAttrName = String(object.refAttrName); + if (object.docString != null) + message.docString = String(object.docString); + switch (object.type) { + default: + if (typeof object.type === "number") { + message.type = object.type; + break; + } + break; + case "UNDEFINED": + case 0: + message.type = 0; + break; + case "FLOAT": + case 1: + message.type = 1; + break; + case "INT": + case 2: + message.type = 2; + break; + case "STRING": + case 3: + message.type = 3; + break; + case "TENSOR": + case 4: + message.type = 4; + break; + case "GRAPH": + case 5: + message.type = 5; + break; + case "SPARSE_TENSOR": + case 11: + message.type = 11; + break; + case "TYPE_PROTO": + case 13: + 
message.type = 13; + break; + case "FLOATS": + case 6: + message.type = 6; + break; + case "INTS": + case 7: + message.type = 7; + break; + case "STRINGS": + case 8: + message.type = 8; + break; + case "TENSORS": + case 9: + message.type = 9; + break; + case "GRAPHS": + case 10: + message.type = 10; + break; + case "SPARSE_TENSORS": + case 12: + message.type = 12; + break; + case "TYPE_PROTOS": + case 14: + message.type = 14; + break; + } + if (object.f != null) + message.f = Number(object.f); + if (object.i != null) + if ($util.Long) + (message.i = $util.Long.fromValue(object.i)).unsigned = false; + else if (typeof object.i === "string") + message.i = parseInt(object.i, 10); + else if (typeof object.i === "number") + message.i = object.i; + else if (typeof object.i === "object") + message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber(); + if (object.s != null) + if (typeof object.s === "string") + $util.base64.decode(object.s, message.s = $util.newBuffer($util.base64.length(object.s)), 0); + else if (object.s.length >= 0) + message.s = object.s; + if (object.t != null) { + if (typeof object.t !== "object") + throw TypeError(".onnx.AttributeProto.t: object expected"); + message.t = $root.onnx.TensorProto.fromObject(object.t); + } + if (object.g != null) { + if (typeof object.g !== "object") + throw TypeError(".onnx.AttributeProto.g: object expected"); + message.g = $root.onnx.GraphProto.fromObject(object.g); + } + if (object.sparseTensor != null) { + if (typeof object.sparseTensor !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensor: object expected"); + message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor); + } + if (object.tp != null) { + if (typeof object.tp !== "object") + throw TypeError(".onnx.AttributeProto.tp: object expected"); + message.tp = $root.onnx.TypeProto.fromObject(object.tp); + } + if (object.floats) { + if (!Array.isArray(object.floats)) + throw 
TypeError(".onnx.AttributeProto.floats: array expected"); + message.floats = []; + for (var i = 0; i < object.floats.length; ++i) + message.floats[i] = Number(object.floats[i]); + } + if (object.ints) { + if (!Array.isArray(object.ints)) + throw TypeError(".onnx.AttributeProto.ints: array expected"); + message.ints = []; + for (var i = 0; i < object.ints.length; ++i) + if ($util.Long) + (message.ints[i] = $util.Long.fromValue(object.ints[i])).unsigned = false; + else if (typeof object.ints[i] === "string") + message.ints[i] = parseInt(object.ints[i], 10); + else if (typeof object.ints[i] === "number") + message.ints[i] = object.ints[i]; + else if (typeof object.ints[i] === "object") + message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber(); + } + if (object.strings) { + if (!Array.isArray(object.strings)) + throw TypeError(".onnx.AttributeProto.strings: array expected"); + message.strings = []; + for (var i = 0; i < object.strings.length; ++i) + if (typeof object.strings[i] === "string") + $util.base64.decode(object.strings[i], message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i])), 0); + else if (object.strings[i].length >= 0) + message.strings[i] = object.strings[i]; + } + if (object.tensors) { + if (!Array.isArray(object.tensors)) + throw TypeError(".onnx.AttributeProto.tensors: array expected"); + message.tensors = []; + for (var i = 0; i < object.tensors.length; ++i) { + if (typeof object.tensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.tensors: object expected"); + message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]); + } + } + if (object.graphs) { + if (!Array.isArray(object.graphs)) + throw TypeError(".onnx.AttributeProto.graphs: array expected"); + message.graphs = []; + for (var i = 0; i < object.graphs.length; ++i) { + if (typeof object.graphs[i] !== "object") + throw TypeError(".onnx.AttributeProto.graphs: object expected"); + message.graphs[i] = 
$root.onnx.GraphProto.fromObject(object.graphs[i]); + } + } + if (object.sparseTensors) { + if (!Array.isArray(object.sparseTensors)) + throw TypeError(".onnx.AttributeProto.sparseTensors: array expected"); + message.sparseTensors = []; + for (var i = 0; i < object.sparseTensors.length; ++i) { + if (typeof object.sparseTensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensors: object expected"); + message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]); + } + } + if (object.typeProtos) { + if (!Array.isArray(object.typeProtos)) + throw TypeError(".onnx.AttributeProto.typeProtos: array expected"); + message.typeProtos = []; + for (var i = 0; i < object.typeProtos.length; ++i) { + if (typeof object.typeProtos[i] !== "object") + throw TypeError(".onnx.AttributeProto.typeProtos: object expected"); + message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]); + } + } + return message; + }; + + /** + * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.AttributeProto + * @static + * @param {onnx.AttributeProto} message AttributeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + AttributeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.floats = []; + object.ints = []; + object.strings = []; + object.tensors = []; + object.graphs = []; + object.typeProtos = []; + object.sparseTensors = []; + } + if (options.defaults) { + object.name = ""; + object.f = 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.i = options.longs === String ? 
"0" : 0; + if (options.bytes === String) + object.s = ""; + else { + object.s = []; + if (options.bytes !== Array) + object.s = $util.newBuffer(object.s); + } + object.t = null; + object.g = null; + object.docString = ""; + object.tp = null; + object.type = options.enums === String ? "UNDEFINED" : 0; + object.refAttrName = ""; + object.sparseTensor = null; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.f != null && message.hasOwnProperty("f")) + object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; + if (message.i != null && message.hasOwnProperty("i")) + if (typeof message.i === "number") + object.i = options.longs === String ? String(message.i) : message.i; + else + object.i = options.longs === String ? $util.Long.prototype.toString.call(message.i) : options.longs === Number ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() : message.i; + if (message.s != null && message.hasOwnProperty("s")) + object.s = options.bytes === String ? $util.base64.encode(message.s, 0, message.s.length) : options.bytes === Array ? Array.prototype.slice.call(message.s) : message.s; + if (message.t != null && message.hasOwnProperty("t")) + object.t = $root.onnx.TensorProto.toObject(message.t, options); + if (message.g != null && message.hasOwnProperty("g")) + object.g = $root.onnx.GraphProto.toObject(message.g, options); + if (message.floats && message.floats.length) { + object.floats = []; + for (var j = 0; j < message.floats.length; ++j) + object.floats[j] = options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; + } + if (message.ints && message.ints.length) { + object.ints = []; + for (var j = 0; j < message.ints.length; ++j) + if (typeof message.ints[j] === "number") + object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; + else + object.ints[j] = options.longs === String ? 
$util.Long.prototype.toString.call(message.ints[j]) : options.longs === Number ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() : message.ints[j]; + } + if (message.strings && message.strings.length) { + object.strings = []; + for (var j = 0; j < message.strings.length; ++j) + object.strings[j] = options.bytes === String ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.strings[j]) : message.strings[j]; + } + if (message.tensors && message.tensors.length) { + object.tensors = []; + for (var j = 0; j < message.tensors.length; ++j) + object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); + } + if (message.graphs && message.graphs.length) { + object.graphs = []; + for (var j = 0; j < message.graphs.length; ++j) + object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.tp != null && message.hasOwnProperty("tp")) + object.tp = $root.onnx.TypeProto.toObject(message.tp, options); + if (message.typeProtos && message.typeProtos.length) { + object.typeProtos = []; + for (var j = 0; j < message.typeProtos.length; ++j) + object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); + } + if (message.type != null && message.hasOwnProperty("type")) + object.type = options.enums === String ? $root.onnx.AttributeProto.AttributeType[message.type] === undefined ? 
message.type : $root.onnx.AttributeProto.AttributeType[message.type] : message.type; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + object.refAttrName = message.refAttrName; + if (message.sparseTensor != null && message.hasOwnProperty("sparseTensor")) + object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); + if (message.sparseTensors && message.sparseTensors.length) { + object.sparseTensors = []; + for (var j = 0; j < message.sparseTensors.length; ++j) + object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); + } + return object; + }; + + /** + * Converts this AttributeProto to JSON. + * @function toJSON + * @memberof onnx.AttributeProto + * @instance + * @returns {Object.} JSON object + */ + AttributeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for AttributeProto + * @function getTypeUrl + * @memberof onnx.AttributeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.AttributeProto"; + }; + + /** + * AttributeType enum. 
+ * @name onnx.AttributeProto.AttributeType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} INT=2 INT value + * @property {number} STRING=3 STRING value + * @property {number} TENSOR=4 TENSOR value + * @property {number} GRAPH=5 GRAPH value + * @property {number} SPARSE_TENSOR=11 SPARSE_TENSOR value + * @property {number} TYPE_PROTO=13 TYPE_PROTO value + * @property {number} FLOATS=6 FLOATS value + * @property {number} INTS=7 INTS value + * @property {number} STRINGS=8 STRINGS value + * @property {number} TENSORS=9 TENSORS value + * @property {number} GRAPHS=10 GRAPHS value + * @property {number} SPARSE_TENSORS=12 SPARSE_TENSORS value + * @property {number} TYPE_PROTOS=14 TYPE_PROTOS value + */ + AttributeProto.AttributeType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + values[valuesById[2] = "INT"] = 2; + values[valuesById[3] = "STRING"] = 3; + values[valuesById[4] = "TENSOR"] = 4; + values[valuesById[5] = "GRAPH"] = 5; + values[valuesById[11] = "SPARSE_TENSOR"] = 11; + values[valuesById[13] = "TYPE_PROTO"] = 13; + values[valuesById[6] = "FLOATS"] = 6; + values[valuesById[7] = "INTS"] = 7; + values[valuesById[8] = "STRINGS"] = 8; + values[valuesById[9] = "TENSORS"] = 9; + values[valuesById[10] = "GRAPHS"] = 10; + values[valuesById[12] = "SPARSE_TENSORS"] = 12; + values[valuesById[14] = "TYPE_PROTOS"] = 14; + return values; + })(); + + return AttributeProto; + })(); + + onnx.ValueInfoProto = (function() { + + /** + * Properties of a ValueInfoProto. + * @memberof onnx + * @interface IValueInfoProto + * @property {string|null} [name] ValueInfoProto name + * @property {onnx.ITypeProto|null} [type] ValueInfoProto type + * @property {string|null} [docString] ValueInfoProto docString + */ + + /** + * Constructs a new ValueInfoProto. 
+ * @memberof onnx + * @classdesc Represents a ValueInfoProto. + * @implements IValueInfoProto + * @constructor + * @param {onnx.IValueInfoProto=} [properties] Properties to set + */ + function ValueInfoProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ValueInfoProto name. + * @member {string} name + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.name = ""; + + /** + * ValueInfoProto type. + * @member {onnx.ITypeProto|null|undefined} type + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.type = null; + + /** + * ValueInfoProto docString. + * @member {string} docString + * @memberof onnx.ValueInfoProto + * @instance + */ + ValueInfoProto.prototype.docString = ""; + + /** + * Creates a new ValueInfoProto instance using the specified properties. + * @function create + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto=} [properties] Properties to set + * @returns {onnx.ValueInfoProto} ValueInfoProto instance + */ + ValueInfoProto.create = function create(properties) { + return new ValueInfoProto(properties); + }; + + /** + * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.docString); + return writer; + }; + + /** + * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.IValueInfoProto} message ValueInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ValueInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.ValueInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 2: { + message.type = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + case 3: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.ValueInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ValueInfoProto} ValueInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ValueInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ValueInfoProto message. 
+ * @function verify + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ValueInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.type != null && message.hasOwnProperty("type")) { + var error = $root.onnx.TypeProto.verify(message.type); + if (error) + return "type." + error; + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ValueInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ValueInfoProto} ValueInfoProto + */ + ValueInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ValueInfoProto) + return object; + var message = new $root.onnx.ValueInfoProto(); + if (object.name != null) + message.name = String(object.name); + if (object.type != null) { + if (typeof object.type !== "object") + throw TypeError(".onnx.ValueInfoProto.type: object expected"); + message.type = $root.onnx.TypeProto.fromObject(object.type); + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.ValueInfoProto + * @static + * @param {onnx.ValueInfoProto} message ValueInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ValueInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.name = ""; + object.type = null; + object.docString = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.type != null && message.hasOwnProperty("type")) + object.type = $root.onnx.TypeProto.toObject(message.type, options); + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + return object; + }; + + /** + * Converts this ValueInfoProto to JSON. + * @function toJSON + * @memberof onnx.ValueInfoProto + * @instance + * @returns {Object.} JSON object + */ + ValueInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ValueInfoProto + * @function getTypeUrl + * @memberof onnx.ValueInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ValueInfoProto"; + }; + + return ValueInfoProto; + })(); + + onnx.NodeProto = (function() { + + /** + * Properties of a NodeProto. 
+ * @memberof onnx + * @interface INodeProto + * @property {Array.|null} [input] NodeProto input + * @property {Array.|null} [output] NodeProto output + * @property {string|null} [name] NodeProto name + * @property {string|null} [opType] NodeProto opType + * @property {string|null} [domain] NodeProto domain + * @property {Array.|null} [attribute] NodeProto attribute + * @property {string|null} [docString] NodeProto docString + */ + + /** + * Constructs a new NodeProto. + * @memberof onnx + * @classdesc Represents a NodeProto. + * @implements INodeProto + * @constructor + * @param {onnx.INodeProto=} [properties] Properties to set + */ + function NodeProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * NodeProto input. + * @member {Array.} input + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.input = $util.emptyArray; + + /** + * NodeProto output. + * @member {Array.} output + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.output = $util.emptyArray; + + /** + * NodeProto name. + * @member {string} name + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.name = ""; + + /** + * NodeProto opType. + * @member {string} opType + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.opType = ""; + + /** + * NodeProto domain. + * @member {string} domain + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.domain = ""; + + /** + * NodeProto attribute. + * @member {Array.} attribute + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.attribute = $util.emptyArray; + + /** + * NodeProto docString. 
+ * @member {string} docString + * @memberof onnx.NodeProto + * @instance + */ + NodeProto.prototype.docString = ""; + + /** + * Creates a new NodeProto instance using the specified properties. + * @function create + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto=} [properties] Properties to set + * @returns {onnx.NodeProto} NodeProto instance + */ + NodeProto.create = function create(properties) { + return new NodeProto(properties); + }; + + /** + * Encodes the specified NodeProto message. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encode + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.output[i]); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.name); + if (message.opType != null && Object.hasOwnProperty.call(message, "opType")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.opType); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + $root.onnx.AttributeProto.encode(message.attribute[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.domain != null && 
Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 7, wireType 2 =*/58).string(message.domain); + return writer; + }; + + /** + * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link onnx.NodeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {onnx.INodeProto} message NodeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + NodeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.NodeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 2: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 3: { + message.name = reader.string(); + break; + } + case 4: { + message.opType = reader.string(); + break; + } + case 7: { + message.domain = reader.string(); + break; + } + case 5: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a NodeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.NodeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.NodeProto} NodeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + NodeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a NodeProto message. 
+ * @function verify + * @memberof onnx.NodeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + NodeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.opType != null && message.hasOwnProperty("opType")) + if (!$util.isString(message.opType)) + return "opType: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attribute[i]); + if (error) + return "attribute." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + return null; + }; + + /** + * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.NodeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.NodeProto} NodeProto + */ + NodeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.NodeProto) + return object; + var message = new $root.onnx.NodeProto(); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.NodeProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.NodeProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.name != null) + message.name = String(object.name); + if (object.opType != null) + message.opType = String(object.opType); + if (object.domain != null) + message.domain = String(object.domain); + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.NodeProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) { + if (typeof object.attribute[i] !== "object") + throw TypeError(".onnx.NodeProto.attribute: object expected"); + message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + return message; + }; + + /** + * Creates a plain object from a NodeProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.NodeProto + * @static + * @param {onnx.NodeProto} message NodeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + NodeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + } + if (options.defaults) { + object.name = ""; + object.opType = ""; + object.docString = ""; + object.domain = ""; + } + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.opType != null && message.hasOwnProperty("opType")) + object.opType = message.opType; + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + return object; + }; + + /** + * Converts this NodeProto to JSON. 
+ * @function toJSON + * @memberof onnx.NodeProto + * @instance + * @returns {Object.} JSON object + */ + NodeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for NodeProto + * @function getTypeUrl + * @memberof onnx.NodeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.NodeProto"; + }; + + return NodeProto; + })(); + + onnx.TrainingInfoProto = (function() { + + /** + * Properties of a TrainingInfoProto. + * @memberof onnx + * @interface ITrainingInfoProto + * @property {onnx.IGraphProto|null} [initialization] TrainingInfoProto initialization + * @property {onnx.IGraphProto|null} [algorithm] TrainingInfoProto algorithm + * @property {Array.|null} [initializationBinding] TrainingInfoProto initializationBinding + * @property {Array.|null} [updateBinding] TrainingInfoProto updateBinding + */ + + /** + * Constructs a new TrainingInfoProto. + * @memberof onnx + * @classdesc Represents a TrainingInfoProto. + * @implements ITrainingInfoProto + * @constructor + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + */ + function TrainingInfoProto(properties) { + this.initializationBinding = []; + this.updateBinding = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TrainingInfoProto initialization. + * @member {onnx.IGraphProto|null|undefined} initialization + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initialization = null; + + /** + * TrainingInfoProto algorithm. 
+ * @member {onnx.IGraphProto|null|undefined} algorithm + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.algorithm = null; + + /** + * TrainingInfoProto initializationBinding. + * @member {Array.} initializationBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.initializationBinding = $util.emptyArray; + + /** + * TrainingInfoProto updateBinding. + * @member {Array.} updateBinding + * @memberof onnx.TrainingInfoProto + * @instance + */ + TrainingInfoProto.prototype.updateBinding = $util.emptyArray; + + /** + * Creates a new TrainingInfoProto instance using the specified properties. + * @function create + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto=} [properties] Properties to set + * @returns {onnx.TrainingInfoProto} TrainingInfoProto instance + */ + TrainingInfoProto.create = function create(properties) { + return new TrainingInfoProto(properties); + }; + + /** + * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.initialization != null && Object.hasOwnProperty.call(message, "initialization")) + $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.algorithm != null && Object.hasOwnProperty.call(message, "algorithm")) + $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.initializationBinding != null && message.initializationBinding.length) + for (var i = 0; i < message.initializationBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.initializationBinding[i], writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.updateBinding != null && message.updateBinding.length) + for (var i = 0; i < message.updateBinding.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.updateBinding[i], writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.ITrainingInfoProto} message TrainingInfoProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TrainingInfoProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.initializationBinding && message.initializationBinding.length)) + message.initializationBinding = []; + message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 4: { + if (!(message.updateBinding && message.updateBinding.length)) + message.updateBinding = []; + message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TrainingInfoProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TrainingInfoProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TrainingInfoProto message. + * @function verify + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TrainingInfoProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.initialization != null && message.hasOwnProperty("initialization")) { + var error = $root.onnx.GraphProto.verify(message.initialization); + if (error) + return "initialization." + error; + } + if (message.algorithm != null && message.hasOwnProperty("algorithm")) { + var error = $root.onnx.GraphProto.verify(message.algorithm); + if (error) + return "algorithm." + error; + } + if (message.initializationBinding != null && message.hasOwnProperty("initializationBinding")) { + if (!Array.isArray(message.initializationBinding)) + return "initializationBinding: array expected"; + for (var i = 0; i < message.initializationBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); + if (error) + return "initializationBinding." 
+ error; + } + } + if (message.updateBinding != null && message.hasOwnProperty("updateBinding")) { + if (!Array.isArray(message.updateBinding)) + return "updateBinding: array expected"; + for (var i = 0; i < message.updateBinding.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); + if (error) + return "updateBinding." + error; + } + } + return null; + }; + + /** + * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TrainingInfoProto} TrainingInfoProto + */ + TrainingInfoProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TrainingInfoProto) + return object; + var message = new $root.onnx.TrainingInfoProto(); + if (object.initialization != null) { + if (typeof object.initialization !== "object") + throw TypeError(".onnx.TrainingInfoProto.initialization: object expected"); + message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); + } + if (object.algorithm != null) { + if (typeof object.algorithm !== "object") + throw TypeError(".onnx.TrainingInfoProto.algorithm: object expected"); + message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); + } + if (object.initializationBinding) { + if (!Array.isArray(object.initializationBinding)) + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: array expected"); + message.initializationBinding = []; + for (var i = 0; i < object.initializationBinding.length; ++i) { + if (typeof object.initializationBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.initializationBinding: object expected"); + message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.initializationBinding[i]); + } + } + if (object.updateBinding) { + if (!Array.isArray(object.updateBinding)) + 
throw TypeError(".onnx.TrainingInfoProto.updateBinding: array expected"); + message.updateBinding = []; + for (var i = 0; i < object.updateBinding.length; ++i) { + if (typeof object.updateBinding[i] !== "object") + throw TypeError(".onnx.TrainingInfoProto.updateBinding: object expected"); + message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TrainingInfoProto + * @static + * @param {onnx.TrainingInfoProto} message TrainingInfoProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TrainingInfoProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.initializationBinding = []; + object.updateBinding = []; + } + if (options.defaults) { + object.initialization = null; + object.algorithm = null; + } + if (message.initialization != null && message.hasOwnProperty("initialization")) + object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); + if (message.algorithm != null && message.hasOwnProperty("algorithm")) + object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); + if (message.initializationBinding && message.initializationBinding.length) { + object.initializationBinding = []; + for (var j = 0; j < message.initializationBinding.length; ++j) + object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.initializationBinding[j], options); + } + if (message.updateBinding && message.updateBinding.length) { + object.updateBinding = []; + for (var j = 0; j < message.updateBinding.length; ++j) + object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); 
+ } + return object; + }; + + /** + * Converts this TrainingInfoProto to JSON. + * @function toJSON + * @memberof onnx.TrainingInfoProto + * @instance + * @returns {Object.} JSON object + */ + TrainingInfoProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TrainingInfoProto + * @function getTypeUrl + * @memberof onnx.TrainingInfoProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TrainingInfoProto"; + }; + + return TrainingInfoProto; + })(); + + onnx.ModelProto = (function() { + + /** + * Properties of a ModelProto. + * @memberof onnx + * @interface IModelProto + * @property {number|Long|null} [irVersion] ModelProto irVersion + * @property {Array.|null} [opsetImport] ModelProto opsetImport + * @property {string|null} [producerName] ModelProto producerName + * @property {string|null} [producerVersion] ModelProto producerVersion + * @property {string|null} [domain] ModelProto domain + * @property {number|Long|null} [modelVersion] ModelProto modelVersion + * @property {string|null} [docString] ModelProto docString + * @property {onnx.IGraphProto|null} [graph] ModelProto graph + * @property {Array.|null} [metadataProps] ModelProto metadataProps + * @property {Array.|null} [trainingInfo] ModelProto trainingInfo + * @property {Array.|null} [functions] ModelProto functions + */ + + /** + * Constructs a new ModelProto. + * @memberof onnx + * @classdesc Represents a ModelProto. 
+ * @implements IModelProto + * @constructor + * @param {onnx.IModelProto=} [properties] Properties to set + */ + function ModelProto(properties) { + this.opsetImport = []; + this.metadataProps = []; + this.trainingInfo = []; + this.functions = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * ModelProto irVersion. + * @member {number|Long} irVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.opsetImport = $util.emptyArray; + + /** + * ModelProto producerName. + * @member {string} producerName + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerName = ""; + + /** + * ModelProto producerVersion. + * @member {string} producerVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.producerVersion = ""; + + /** + * ModelProto domain. + * @member {string} domain + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.domain = ""; + + /** + * ModelProto modelVersion. + * @member {number|Long} modelVersion + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * ModelProto docString. + * @member {string} docString + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.docString = ""; + + /** + * ModelProto graph. + * @member {onnx.IGraphProto|null|undefined} graph + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.graph = null; + + /** + * ModelProto metadataProps. 
+ * @member {Array.} metadataProps + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.metadataProps = $util.emptyArray; + + /** + * ModelProto trainingInfo. + * @member {Array.} trainingInfo + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.trainingInfo = $util.emptyArray; + + /** + * ModelProto functions. + * @member {Array.} functions + * @memberof onnx.ModelProto + * @instance + */ + ModelProto.prototype.functions = $util.emptyArray; + + /** + * Creates a new ModelProto instance using the specified properties. + * @function create + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto=} [properties] Properties to set + * @returns {onnx.ModelProto} ModelProto instance + */ + ModelProto.create = function create(properties) { + return new ModelProto(properties); + }; + + /** + * Encodes the specified ModelProto message. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. + * @function encode + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.irVersion != null && Object.hasOwnProperty.call(message, "irVersion")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.irVersion); + if (message.producerName != null && Object.hasOwnProperty.call(message, "producerName")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.producerName); + if (message.producerVersion != null && Object.hasOwnProperty.call(message, "producerVersion")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.producerVersion); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.domain); + if (message.modelVersion != null && 
Object.hasOwnProperty.call(message, "modelVersion")) + writer.uint32(/* id 5, wireType 0 =*/40).int64(message.modelVersion); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.docString); + if (message.graph != null && Object.hasOwnProperty.call(message, "graph")) + $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.metadataProps != null && message.metadataProps.length) + for (var i = 0; i < message.metadataProps.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.metadataProps[i], writer.uint32(/* id 14, wireType 2 =*/114).fork()).ldelim(); + if (message.trainingInfo != null && message.trainingInfo.length) + for (var i = 0; i < message.trainingInfo.length; ++i) + $root.onnx.TrainingInfoProto.encode(message.trainingInfo[i], writer.uint32(/* id 20, wireType 2 =*/162).fork()).ldelim(); + if (message.functions != null && message.functions.length) + for (var i = 0; i < message.functions.length; ++i) + $root.onnx.FunctionProto.encode(message.functions[i], writer.uint32(/* id 25, wireType 2 =*/202).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link onnx.ModelProto.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {onnx.IModelProto} message ModelProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + ModelProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.ModelProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.irVersion = reader.int64(); + break; + } + case 8: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.producerName = reader.string(); + break; + } + case 3: { + message.producerVersion = reader.string(); + break; + } + case 4: { + message.domain = reader.string(); + break; + } + case 5: { + message.modelVersion = reader.int64(); + break; + } + case 6: { + message.docString = reader.string(); + break; + } + case 7: { + message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); + break; + } + case 14: { + if (!(message.metadataProps && message.metadataProps.length)) + message.metadataProps = []; + message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 20: { + if (!(message.trainingInfo && message.trainingInfo.length)) + message.trainingInfo = []; + message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); + break; + } + case 25: { + if (!(message.functions && message.functions.length)) + message.functions = []; + message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a ModelProto message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.ModelProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.ModelProto} ModelProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + ModelProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a ModelProto message. + * @function verify + * @memberof onnx.ModelProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + ModelProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (!$util.isInteger(message.irVersion) && !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high))) + return "irVersion: integer|Long expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." 
+ error; + } + } + if (message.producerName != null && message.hasOwnProperty("producerName")) + if (!$util.isString(message.producerName)) + return "producerName: string expected"; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + if (!$util.isString(message.producerVersion)) + return "producerVersion: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (!$util.isInteger(message.modelVersion) && !(message.modelVersion && $util.isInteger(message.modelVersion.low) && $util.isInteger(message.modelVersion.high))) + return "modelVersion: integer|Long expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.graph != null && message.hasOwnProperty("graph")) { + var error = $root.onnx.GraphProto.verify(message.graph); + if (error) + return "graph." + error; + } + if (message.metadataProps != null && message.hasOwnProperty("metadataProps")) { + if (!Array.isArray(message.metadataProps)) + return "metadataProps: array expected"; + for (var i = 0; i < message.metadataProps.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); + if (error) + return "metadataProps." + error; + } + } + if (message.trainingInfo != null && message.hasOwnProperty("trainingInfo")) { + if (!Array.isArray(message.trainingInfo)) + return "trainingInfo: array expected"; + for (var i = 0; i < message.trainingInfo.length; ++i) { + var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); + if (error) + return "trainingInfo." 
+ error; + } + } + if (message.functions != null && message.hasOwnProperty("functions")) { + if (!Array.isArray(message.functions)) + return "functions: array expected"; + for (var i = 0; i < message.functions.length; ++i) { + var error = $root.onnx.FunctionProto.verify(message.functions[i]); + if (error) + return "functions." + error; + } + } + return null; + }; + + /** + * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.ModelProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.ModelProto} ModelProto + */ + ModelProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.ModelProto) + return object; + var message = new $root.onnx.ModelProto(); + if (object.irVersion != null) + if ($util.Long) + (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; + else if (typeof object.irVersion === "string") + message.irVersion = parseInt(object.irVersion, 10); + else if (typeof object.irVersion === "number") + message.irVersion = object.irVersion; + else if (typeof object.irVersion === "object") + message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.ModelProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.ModelProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.producerName != null) + message.producerName = String(object.producerName); + if (object.producerVersion != null) + message.producerVersion = String(object.producerVersion); + if (object.domain != null) + message.domain = String(object.domain); + if 
(object.modelVersion != null) + if ($util.Long) + (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; + else if (typeof object.modelVersion === "string") + message.modelVersion = parseInt(object.modelVersion, 10); + else if (typeof object.modelVersion === "number") + message.modelVersion = object.modelVersion; + else if (typeof object.modelVersion === "object") + message.modelVersion = new $util.LongBits(object.modelVersion.low >>> 0, object.modelVersion.high >>> 0).toNumber(); + if (object.docString != null) + message.docString = String(object.docString); + if (object.graph != null) { + if (typeof object.graph !== "object") + throw TypeError(".onnx.ModelProto.graph: object expected"); + message.graph = $root.onnx.GraphProto.fromObject(object.graph); + } + if (object.metadataProps) { + if (!Array.isArray(object.metadataProps)) + throw TypeError(".onnx.ModelProto.metadataProps: array expected"); + message.metadataProps = []; + for (var i = 0; i < object.metadataProps.length; ++i) { + if (typeof object.metadataProps[i] !== "object") + throw TypeError(".onnx.ModelProto.metadataProps: object expected"); + message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); + } + } + if (object.trainingInfo) { + if (!Array.isArray(object.trainingInfo)) + throw TypeError(".onnx.ModelProto.trainingInfo: array expected"); + message.trainingInfo = []; + for (var i = 0; i < object.trainingInfo.length; ++i) { + if (typeof object.trainingInfo[i] !== "object") + throw TypeError(".onnx.ModelProto.trainingInfo: object expected"); + message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); + } + } + if (object.functions) { + if (!Array.isArray(object.functions)) + throw TypeError(".onnx.ModelProto.functions: array expected"); + message.functions = []; + for (var i = 0; i < object.functions.length; ++i) { + if (typeof object.functions[i] !== "object") + throw 
TypeError(".onnx.ModelProto.functions: object expected"); + message.functions[i] = $root.onnx.FunctionProto.fromObject(object.functions[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a ModelProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.ModelProto + * @static + * @param {onnx.ModelProto} message ModelProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + ModelProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.opsetImport = []; + object.metadataProps = []; + object.trainingInfo = []; + object.functions = []; + } + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.irVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.irVersion = options.longs === String ? "0" : 0; + object.producerName = ""; + object.producerVersion = ""; + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.modelVersion = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.modelVersion = options.longs === String ? "0" : 0; + object.docString = ""; + object.graph = null; + } + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (typeof message.irVersion === "number") + object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; + else + object.irVersion = options.longs === String ? $util.Long.prototype.toString.call(message.irVersion) : options.longs === Number ? 
new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() : message.irVersion; + if (message.producerName != null && message.hasOwnProperty("producerName")) + object.producerName = message.producerName; + if (message.producerVersion != null && message.hasOwnProperty("producerVersion")) + object.producerVersion = message.producerVersion; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.modelVersion != null && message.hasOwnProperty("modelVersion")) + if (typeof message.modelVersion === "number") + object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; + else + object.modelVersion = options.longs === String ? $util.Long.prototype.toString.call(message.modelVersion) : options.longs === Number ? new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() : message.modelVersion; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.graph != null && message.hasOwnProperty("graph")) + object.graph = $root.onnx.GraphProto.toObject(message.graph, options); + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.metadataProps && message.metadataProps.length) { + object.metadataProps = []; + for (var j = 0; j < message.metadataProps.length; ++j) + object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); + } + if (message.trainingInfo && message.trainingInfo.length) { + object.trainingInfo = []; + for (var j = 0; j < message.trainingInfo.length; ++j) + object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); + } + if (message.functions && 
message.functions.length) { + object.functions = []; + for (var j = 0; j < message.functions.length; ++j) + object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + } + return object; + }; + + /** + * Converts this ModelProto to JSON. + * @function toJSON + * @memberof onnx.ModelProto + * @instance + * @returns {Object.} JSON object + */ + ModelProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for ModelProto + * @function getTypeUrl + * @memberof onnx.ModelProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.ModelProto"; + }; + + return ModelProto; + })(); + + onnx.StringStringEntryProto = (function() { + + /** + * Properties of a StringStringEntryProto. + * @memberof onnx + * @interface IStringStringEntryProto + * @property {string|null} [key] StringStringEntryProto key + * @property {string|null} [value] StringStringEntryProto value + */ + + /** + * Constructs a new StringStringEntryProto. + * @memberof onnx + * @classdesc Represents a StringStringEntryProto. + * @implements IStringStringEntryProto + * @constructor + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + */ + function StringStringEntryProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * StringStringEntryProto key. + * @member {string} key + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.key = ""; + + /** + * StringStringEntryProto value. 
+ * @member {string} value + * @memberof onnx.StringStringEntryProto + * @instance + */ + StringStringEntryProto.prototype.value = ""; + + /** + * Creates a new StringStringEntryProto instance using the specified properties. + * @function create + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto=} [properties] Properties to set + * @returns {onnx.StringStringEntryProto} StringStringEntryProto instance + */ + StringStringEntryProto.create = function create(properties) { + return new StringStringEntryProto(properties); + }; + + /** + * Encodes the specified StringStringEntryProto message. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. + * @function encode + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.key != null && Object.hasOwnProperty.call(message, "key")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.key); + if (message.value != null && Object.hasOwnProperty.call(message, "value")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.value); + return writer; + }; + + /** + * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link onnx.StringStringEntryProto.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.IStringStringEntryProto} message StringStringEntryProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.StringStringEntryProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.key = reader.string(); + break; + } + case 2: { + message.value = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.StringStringEntryProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + StringStringEntryProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a StringStringEntryProto message. + * @function verify + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + StringStringEntryProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.key != null && message.hasOwnProperty("key")) + if (!$util.isString(message.key)) + return "key: string expected"; + if (message.value != null && message.hasOwnProperty("value")) + if (!$util.isString(message.value)) + return "value: string expected"; + return null; + }; + + /** + * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.StringStringEntryProto} StringStringEntryProto + */ + StringStringEntryProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.StringStringEntryProto) + return object; + var message = new $root.onnx.StringStringEntryProto(); + if (object.key != null) + message.key = String(object.key); + if (object.value != null) + message.value = String(object.value); + return message; + }; + + /** + * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.StringStringEntryProto + * @static + * @param {onnx.StringStringEntryProto} message StringStringEntryProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + StringStringEntryProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.key = ""; + object.value = ""; + } + if (message.key != null && message.hasOwnProperty("key")) + object.key = message.key; + if (message.value != null && message.hasOwnProperty("value")) + object.value = message.value; + return object; + }; + + /** + * Converts this StringStringEntryProto to JSON. 
+ * @function toJSON + * @memberof onnx.StringStringEntryProto + * @instance + * @returns {Object.} JSON object + */ + StringStringEntryProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for StringStringEntryProto + * @function getTypeUrl + * @memberof onnx.StringStringEntryProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.StringStringEntryProto"; + }; + + return StringStringEntryProto; + })(); + + onnx.TensorAnnotation = (function() { + + /** + * Properties of a TensorAnnotation. + * @memberof onnx + * @interface ITensorAnnotation + * @property {string|null} [tensorName] TensorAnnotation tensorName + * @property {Array.|null} [quantParameterTensorNames] TensorAnnotation quantParameterTensorNames + */ + + /** + * Constructs a new TensorAnnotation. + * @memberof onnx + * @classdesc Represents a TensorAnnotation. + * @implements ITensorAnnotation + * @constructor + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + */ + function TensorAnnotation(properties) { + this.quantParameterTensorNames = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorAnnotation tensorName. + * @member {string} tensorName + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.tensorName = ""; + + /** + * TensorAnnotation quantParameterTensorNames. 
+ * @member {Array.} quantParameterTensorNames + * @memberof onnx.TensorAnnotation + * @instance + */ + TensorAnnotation.prototype.quantParameterTensorNames = $util.emptyArray; + + /** + * Creates a new TensorAnnotation instance using the specified properties. + * @function create + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation=} [properties] Properties to set + * @returns {onnx.TensorAnnotation} TensorAnnotation instance + */ + TensorAnnotation.create = function create(properties) { + return new TensorAnnotation(properties); + }; + + /** + * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. + * @function encode + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorName != null && Object.hasOwnProperty.call(message, "tensorName")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.tensorName); + if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.quantParameterTensorNames[i], writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.ITensorAnnotation} message TensorAnnotation message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorAnnotation(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorName = reader.string(); + break; + } + case 2: { + if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) + message.quantParameterTensorNames = []; + message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TensorAnnotation + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorAnnotation} TensorAnnotation + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorAnnotation.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorAnnotation message. + * @function verify + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.<string,*>} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorAnnotation.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + if (!$util.isString(message.tensorName)) + return "tensorName: string expected"; + if (message.quantParameterTensorNames != null && message.hasOwnProperty("quantParameterTensorNames")) { + if (!Array.isArray(message.quantParameterTensorNames)) + return "quantParameterTensorNames: array expected"; + for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); + if (error) + return "quantParameterTensorNames." + error; + } + } + return null; + }; + + /** + * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types.
+ * @function fromObject + * @memberof onnx.TensorAnnotation + * @static + * @param {Object.<string,*>} object Plain object + * @returns {onnx.TensorAnnotation} TensorAnnotation + */ + TensorAnnotation.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorAnnotation) + return object; + var message = new $root.onnx.TensorAnnotation(); + if (object.tensorName != null) + message.tensorName = String(object.tensorName); + if (object.quantParameterTensorNames) { + if (!Array.isArray(object.quantParameterTensorNames)) + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: array expected"); + message.quantParameterTensorNames = []; + for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { + if (typeof object.quantParameterTensorNames[i] !== "object") + throw TypeError(".onnx.TensorAnnotation.quantParameterTensorNames: object expected"); + message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject(object.quantParameterTensorNames[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified.
+ * @function toObject + * @memberof onnx.TensorAnnotation + * @static + * @param {onnx.TensorAnnotation} message TensorAnnotation + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.<string,*>} Plain object + */ + TensorAnnotation.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.quantParameterTensorNames = []; + if (options.defaults) + object.tensorName = ""; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + object.tensorName = message.tensorName; + if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { + object.quantParameterTensorNames = []; + for (var j = 0; j < message.quantParameterTensorNames.length; ++j) + object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject(message.quantParameterTensorNames[j], options); + } + return object; + }; + + /** + * Converts this TensorAnnotation to JSON. + * @function toJSON + * @memberof onnx.TensorAnnotation + * @instance + * @returns {Object.<string,*>} JSON object + */ + TensorAnnotation.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorAnnotation + * @function getTypeUrl + * @memberof onnx.TensorAnnotation + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorAnnotation"; + }; + + return TensorAnnotation; + })(); + + onnx.GraphProto = (function() { + + /** + * Properties of a GraphProto.
+ * @memberof onnx + * @interface IGraphProto + * @property {Array.<onnx.INodeProto>|null} [node] GraphProto node + * @property {string|null} [name] GraphProto name + * @property {Array.<onnx.ITensorProto>|null} [initializer] GraphProto initializer + * @property {Array.<onnx.ISparseTensorProto>|null} [sparseInitializer] GraphProto sparseInitializer + * @property {string|null} [docString] GraphProto docString + * @property {Array.<onnx.IValueInfoProto>|null} [input] GraphProto input + * @property {Array.<onnx.IValueInfoProto>|null} [output] GraphProto output + * @property {Array.<onnx.IValueInfoProto>|null} [valueInfo] GraphProto valueInfo + * @property {Array.<onnx.ITensorAnnotation>|null} [quantizationAnnotation] GraphProto quantizationAnnotation + */ + + /** + * Constructs a new GraphProto. + * @memberof onnx + * @classdesc Represents a GraphProto. + * @implements IGraphProto + * @constructor + * @param {onnx.IGraphProto=} [properties] Properties to set + */ + function GraphProto(properties) { + this.node = []; + this.initializer = []; + this.sparseInitializer = []; + this.input = []; + this.output = []; + this.valueInfo = []; + this.quantizationAnnotation = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * GraphProto node. + * @member {Array.<onnx.INodeProto>} node + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.node = $util.emptyArray; + + /** + * GraphProto name. + * @member {string} name + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.name = ""; + + /** + * GraphProto initializer. + * @member {Array.<onnx.ITensorProto>} initializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.initializer = $util.emptyArray; + + /** + * GraphProto sparseInitializer. + * @member {Array.<onnx.ISparseTensorProto>} sparseInitializer + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.sparseInitializer = $util.emptyArray; + + /** + * GraphProto docString.
+ * @member {string} docString + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.docString = ""; + + /** + * GraphProto input. + * @member {Array.<onnx.IValueInfoProto>} input + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.input = $util.emptyArray; + + /** + * GraphProto output. + * @member {Array.<onnx.IValueInfoProto>} output + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.output = $util.emptyArray; + + /** + * GraphProto valueInfo. + * @member {Array.<onnx.IValueInfoProto>} valueInfo + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.valueInfo = $util.emptyArray; + + /** + * GraphProto quantizationAnnotation. + * @member {Array.<onnx.ITensorAnnotation>} quantizationAnnotation + * @memberof onnx.GraphProto + * @instance + */ + GraphProto.prototype.quantizationAnnotation = $util.emptyArray; + + /** + * Creates a new GraphProto instance using the specified properties. + * @function create + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto=} [properties] Properties to set + * @returns {onnx.GraphProto} GraphProto instance + */ + GraphProto.create = function create(properties) { + return new GraphProto(properties); + }; + + /** + * Encodes the specified GraphProto message. Does not implicitly {@link onnx.GraphProto.verify|verify} messages.
+ * @function encode + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.name); + if (message.initializer != null && message.initializer.length) + for (var i = 0; i < message.initializer.length; ++i) + $root.onnx.TensorProto.encode(message.initializer[i], writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.docString); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + $root.onnx.ValueInfoProto.encode(message.input[i], writer.uint32(/* id 11, wireType 2 =*/90).fork()).ldelim(); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + $root.onnx.ValueInfoProto.encode(message.output[i], writer.uint32(/* id 12, wireType 2 =*/98).fork()).ldelim(); + if (message.valueInfo != null && message.valueInfo.length) + for (var i = 0; i < message.valueInfo.length; ++i) + $root.onnx.ValueInfoProto.encode(message.valueInfo[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) + for (var i = 0; i < message.quantizationAnnotation.length; ++i) + $root.onnx.TensorAnnotation.encode(message.quantizationAnnotation[i], writer.uint32(/* id
14, wireType 2 =*/114).fork()).ldelim(); + if (message.sparseInitializer != null && message.sparseInitializer.length) + for (var i = 0; i < message.sparseInitializer.length; ++i) + $root.onnx.SparseTensorProto.encode(message.sparseInitializer[i], writer.uint32(/* id 15, wireType 2 =*/122).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link onnx.GraphProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {onnx.IGraphProto} message GraphProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + GraphProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer. Repeated fields are appended in the order they are read; unrecognized field numbers are skipped by their wire type. + * @function decode + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ?
reader.len : reader.pos + length, message = new $root.onnx.GraphProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 2: { + message.name = reader.string(); + break; + } + case 5: { + if (!(message.initializer && message.initializer.length)) + message.initializer = []; + message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + break; + } + case 15: { + if (!(message.sparseInitializer && message.sparseInitializer.length)) + message.sparseInitializer = []; + message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.docString = reader.string(); + break; + } + case 11: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 12: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 13: { + if (!(message.valueInfo && message.valueInfo.length)) + message.valueInfo = []; + message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + break; + } + case 14: { + if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) + message.quantizationAnnotation = []; + message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a GraphProto message from the specified reader or buffer, length delimited.
+ * @function decodeDelimited + * @memberof onnx.GraphProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.GraphProto} GraphProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + GraphProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a GraphProto message. + * @function verify + * @memberof onnx.GraphProto + * @static + * @param {Object.<string,*>} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + GraphProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.initializer != null && message.hasOwnProperty("initializer")) { + if (!Array.isArray(message.initializer)) + return "initializer: array expected"; + for (var i = 0; i < message.initializer.length; ++i) { + var error = $root.onnx.TensorProto.verify(message.initializer[i]); + if (error) + return "initializer."
+ error; + } + } + if (message.sparseInitializer != null && message.hasOwnProperty("sparseInitializer")) { + if (!Array.isArray(message.sparseInitializer)) + return "sparseInitializer: array expected"; + for (var i = 0; i < message.sparseInitializer.length; ++i) { + var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); + if (error) + return "sparseInitializer." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.input[i]); + if (error) + return "input." + error; + } + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.output[i]); + if (error) + return "output." + error; + } + } + if (message.valueInfo != null && message.hasOwnProperty("valueInfo")) { + if (!Array.isArray(message.valueInfo)) + return "valueInfo: array expected"; + for (var i = 0; i < message.valueInfo.length; ++i) { + var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); + if (error) + return "valueInfo." + error; + } + } + if (message.quantizationAnnotation != null && message.hasOwnProperty("quantizationAnnotation")) { + if (!Array.isArray(message.quantizationAnnotation)) + return "quantizationAnnotation: array expected"; + for (var i = 0; i < message.quantizationAnnotation.length; ++i) { + var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); + if (error) + return "quantizationAnnotation."
+ error; + } + } + return null; + }; + + /** + * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.GraphProto + * @static + * @param {Object.<string,*>} object Plain object + * @returns {onnx.GraphProto} GraphProto + */ + GraphProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.GraphProto) + return object; + var message = new $root.onnx.GraphProto(); + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.GraphProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.GraphProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } + } + if (object.name != null) + message.name = String(object.name); + if (object.initializer) { + if (!Array.isArray(object.initializer)) + throw TypeError(".onnx.GraphProto.initializer: array expected"); + message.initializer = []; + for (var i = 0; i < object.initializer.length; ++i) { + if (typeof object.initializer[i] !== "object") + throw TypeError(".onnx.GraphProto.initializer: object expected"); + message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); + } + } + if (object.sparseInitializer) { + if (!Array.isArray(object.sparseInitializer)) + throw TypeError(".onnx.GraphProto.sparseInitializer: array expected"); + message.sparseInitializer = []; + for (var i = 0; i < object.sparseInitializer.length; ++i) { + if (typeof object.sparseInitializer[i] !== "object") + throw TypeError(".onnx.GraphProto.sparseInitializer: object expected"); + message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); + } + } + if (object.docString != null) + message.docString = String(object.docString); + if (object.input) { + if (!Array.isArray(object.input)) + throw
TypeError(".onnx.GraphProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) { + if (typeof object.input[i] !== "object") + throw TypeError(".onnx.GraphProto.input: object expected"); + message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); + } + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.GraphProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) { + if (typeof object.output[i] !== "object") + throw TypeError(".onnx.GraphProto.output: object expected"); + message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); + } + } + if (object.valueInfo) { + if (!Array.isArray(object.valueInfo)) + throw TypeError(".onnx.GraphProto.valueInfo: array expected"); + message.valueInfo = []; + for (var i = 0; i < object.valueInfo.length; ++i) { + if (typeof object.valueInfo[i] !== "object") + throw TypeError(".onnx.GraphProto.valueInfo: object expected"); + message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); + } + } + if (object.quantizationAnnotation) { + if (!Array.isArray(object.quantizationAnnotation)) + throw TypeError(".onnx.GraphProto.quantizationAnnotation: array expected"); + message.quantizationAnnotation = []; + for (var i = 0; i < object.quantizationAnnotation.length; ++i) { + if (typeof object.quantizationAnnotation[i] !== "object") + throw TypeError(".onnx.GraphProto.quantizationAnnotation: object expected"); + message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a GraphProto message. Also converts values to other types if specified.
+ * @function toObject + * @memberof onnx.GraphProto + * @static + * @param {onnx.GraphProto} message GraphProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.<string,*>} Plain object + */ + GraphProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.node = []; + object.initializer = []; + object.input = []; + object.output = []; + object.valueInfo = []; + object.quantizationAnnotation = []; + object.sparseInitializer = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + } + if (message.node && message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.initializer && message.initializer.length) { + object.initializer = []; + for (var j = 0; j < message.initializer.length; ++j) + object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); + } + if (message.valueInfo && message.valueInfo.length) { + object.valueInfo = []; + for (var j = 0; j < message.valueInfo.length; ++j) + object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); + } + if (message.quantizationAnnotation &&
message.quantizationAnnotation.length) { + object.quantizationAnnotation = []; + for (var j = 0; j < message.quantizationAnnotation.length; ++j) + object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject(message.quantizationAnnotation[j], options); + } + if (message.sparseInitializer && message.sparseInitializer.length) { + object.sparseInitializer = []; + for (var j = 0; j < message.sparseInitializer.length; ++j) + object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); + } + return object; + }; + + /** + * Converts this GraphProto to JSON. + * @function toJSON + * @memberof onnx.GraphProto + * @instance + * @returns {Object.<string,*>} JSON object + */ + GraphProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for GraphProto + * @function getTypeUrl + * @memberof onnx.GraphProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.GraphProto"; + }; + + return GraphProto; + })(); + + onnx.TensorProto = (function() { + + /** + * Properties of a TensorProto.
+ * @memberof onnx + * @interface ITensorProto + * @property {Array.<number|Long>|null} [dims] TensorProto dims + * @property {number|null} [dataType] TensorProto dataType + * @property {onnx.TensorProto.ISegment|null} [segment] TensorProto segment + * @property {Array.<number>|null} [floatData] TensorProto floatData + * @property {Array.<number>|null} [int32Data] TensorProto int32Data + * @property {Array.<Uint8Array>|null} [stringData] TensorProto stringData + * @property {Array.<number|Long>|null} [int64Data] TensorProto int64Data + * @property {string|null} [name] TensorProto name + * @property {string|null} [docString] TensorProto docString + * @property {Uint8Array|null} [rawData] TensorProto rawData + * @property {Array.<onnx.IStringStringEntryProto>|null} [externalData] TensorProto externalData + * @property {onnx.TensorProto.DataLocation|null} [dataLocation] TensorProto dataLocation + * @property {Array.<number>|null} [doubleData] TensorProto doubleData + * @property {Array.<number|Long>|null} [uint64Data] TensorProto uint64Data + */ + + /** + * Constructs a new TensorProto. + * @memberof onnx + * @classdesc Represents a TensorProto. + * @implements ITensorProto + * @constructor + * @param {onnx.ITensorProto=} [properties] Properties to set + */ + function TensorProto(properties) { + this.dims = []; + this.floatData = []; + this.int32Data = []; + this.stringData = []; + this.int64Data = []; + this.externalData = []; + this.doubleData = []; + this.uint64Data = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorProto dims. + * @member {Array.<number|Long>} dims + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dims = $util.emptyArray; + + /** + * TensorProto dataType. + * @member {number} dataType + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataType = 0; + + /** + * TensorProto segment.
+ * @member {onnx.TensorProto.ISegment|null|undefined} segment + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.segment = null; + + /** + * TensorProto floatData. + * @member {Array.<number>} floatData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.floatData = $util.emptyArray; + + /** + * TensorProto int32Data. + * @member {Array.<number>} int32Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int32Data = $util.emptyArray; + + /** + * TensorProto stringData. + * @member {Array.<Uint8Array>} stringData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.stringData = $util.emptyArray; + + /** + * TensorProto int64Data. + * @member {Array.<number|Long>} int64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.int64Data = $util.emptyArray; + + /** + * TensorProto name. + * @member {string} name + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.name = ""; + + /** + * TensorProto docString. + * @member {string} docString + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.docString = ""; + + /** + * TensorProto rawData. + * @member {Uint8Array} rawData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.rawData = $util.newBuffer([]); + + /** + * TensorProto externalData. + * @member {Array.<onnx.IStringStringEntryProto>} externalData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.externalData = $util.emptyArray; + + /** + * TensorProto dataLocation. + * @member {onnx.TensorProto.DataLocation} dataLocation + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.dataLocation = 0; + + /** + * TensorProto doubleData. + * @member {Array.<number>} doubleData + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.doubleData = $util.emptyArray; + + /** + * TensorProto uint64Data.
+ * @member {Array.<number|Long>} uint64Data + * @memberof onnx.TensorProto + * @instance + */ + TensorProto.prototype.uint64Data = $util.emptyArray; + + /** + * Creates a new TensorProto instance using the specified properties. + * @function create + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto=} [properties] Properties to set + * @returns {onnx.TensorProto} TensorProto instance + */ + TensorProto.create = function create(properties) { + return new TensorProto(properties); + }; + + /** + * Encodes the specified TensorProto message. Does not implicitly {@link onnx.TensorProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 1, wireType 2 =*/10).fork(); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + if (message.dataType != null && Object.hasOwnProperty.call(message, "dataType")) + writer.uint32(/* id 2, wireType 0 =*/16).int32(message.dataType); + if (message.segment != null && Object.hasOwnProperty.call(message, "segment")) + $root.onnx.TensorProto.Segment.encode(message.segment, writer.uint32(/* id 3, wireType 2 =*/26).fork()).ldelim(); + if (message.floatData != null && message.floatData.length) { + writer.uint32(/* id 4, wireType 2 =*/34).fork(); + for (var i = 0; i < message.floatData.length; ++i) + writer.float(message.floatData[i]); + writer.ldelim(); + } + if (message.int32Data != null && message.int32Data.length) { + writer.uint32(/* id 5, wireType 2 =*/42).fork(); + for (var i = 0; i < message.int32Data.length; ++i) + writer.int32(message.int32Data[i]); + writer.ldelim(); + } + if
(message.stringData != null && message.stringData.length) + for (var i = 0; i < message.stringData.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).bytes(message.stringData[i]); + if (message.int64Data != null && message.int64Data.length) { + writer.uint32(/* id 7, wireType 2 =*/58).fork(); + for (var i = 0; i < message.int64Data.length; ++i) + writer.int64(message.int64Data[i]); + writer.ldelim(); + } + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.name); + if (message.rawData != null && Object.hasOwnProperty.call(message, "rawData")) + writer.uint32(/* id 9, wireType 2 =*/74).bytes(message.rawData); + if (message.doubleData != null && message.doubleData.length) { + writer.uint32(/* id 10, wireType 2 =*/82).fork(); + for (var i = 0; i < message.doubleData.length; ++i) + writer.double(message.doubleData[i]); + writer.ldelim(); + } + if (message.uint64Data != null && message.uint64Data.length) { + writer.uint32(/* id 11, wireType 2 =*/90).fork(); + for (var i = 0; i < message.uint64Data.length; ++i) + writer.uint64(message.uint64Data[i]); + writer.ldelim(); + } + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 12, wireType 2 =*/98).string(message.docString); + if (message.externalData != null && message.externalData.length) + for (var i = 0; i < message.externalData.length; ++i) + $root.onnx.StringStringEntryProto.encode(message.externalData[i], writer.uint32(/* id 13, wireType 2 =*/106).fork()).ldelim(); + if (message.dataLocation != null && Object.hasOwnProperty.call(message, "dataLocation")) + writer.uint32(/* id 14, wireType 0 =*/112).int32(message.dataLocation); + return writer; + }; + + /** + * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link onnx.TensorProto.verify|verify} messages.
+ * @function encodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {onnx.ITensorProto} message TensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + case 2: { + message.dataType = reader.int32(); + break; + } + case 3: { + message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); + break; + } + case 4: { + if (!(message.floatData && message.floatData.length)) + message.floatData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.floatData.push(reader.float()); + } else + message.floatData.push(reader.float()); + break; + } + case 5: { + if (!(message.int32Data && message.int32Data.length)) + message.int32Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int32Data.push(reader.int32()); + } else + message.int32Data.push(reader.int32()); + break; + } + case 6: { + if (!(message.stringData && message.stringData.length)) + message.stringData = []; + message.stringData.push(reader.bytes()); + break; + } + case 7: { + if (!(message.int64Data && message.int64Data.length)) + message.int64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.int64Data.push(reader.int64()); + } else + message.int64Data.push(reader.int64()); + break; + } + case 8: { + message.name = reader.string(); + break; + } + case 12: { + message.docString = reader.string(); + break; + } + case 9: { + message.rawData = reader.bytes(); + break; + } + case 13: { + if (!(message.externalData && message.externalData.length)) + message.externalData = []; + 
message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + break; + } + case 14: { + message.dataLocation = reader.int32(); + break; + } + case 10: { + if (!(message.doubleData && message.doubleData.length)) + message.doubleData = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.doubleData.push(reader.double()); + } else + message.doubleData.push(reader.double()); + break; + } + case 11: { + if (!(message.uint64Data && message.uint64Data.length)) + message.uint64Data = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.uint64Data.push(reader.uint64()); + } else + message.uint64Data.push(reader.uint64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorProto message from the specified reader or buffer, length delimited. + * The message length is read from the reader as a uint32 prefix. + * @function decodeDelimited + * @memberof onnx.TensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto} TensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorProto message. 
+ * @function verify + * @memberof onnx.TensorProto + * @static + * @param {Object.<string,*>} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + if (!$util.isInteger(message.dataType)) + return "dataType: integer expected"; + if (message.segment != null && message.hasOwnProperty("segment")) { + var error = $root.onnx.TensorProto.Segment.verify(message.segment); + if (error) + return "segment." 
+ error; + } + if (message.floatData != null && message.hasOwnProperty("floatData")) { + if (!Array.isArray(message.floatData)) + return "floatData: array expected"; + for (var i = 0; i < message.floatData.length; ++i) + if (typeof message.floatData[i] !== "number") + return "floatData: number[] expected"; + } + if (message.int32Data != null && message.hasOwnProperty("int32Data")) { + if (!Array.isArray(message.int32Data)) + return "int32Data: array expected"; + for (var i = 0; i < message.int32Data.length; ++i) + if (!$util.isInteger(message.int32Data[i])) + return "int32Data: integer[] expected"; + } + if (message.stringData != null && message.hasOwnProperty("stringData")) { + if (!Array.isArray(message.stringData)) + return "stringData: array expected"; + for (var i = 0; i < message.stringData.length; ++i) + if (!(message.stringData[i] && typeof message.stringData[i].length === "number" || $util.isString(message.stringData[i]))) + return "stringData: buffer[] expected"; + } + if (message.int64Data != null && message.hasOwnProperty("int64Data")) { + if (!Array.isArray(message.int64Data)) + return "int64Data: array expected"; + for (var i = 0; i < message.int64Data.length; ++i) + if (!$util.isInteger(message.int64Data[i]) && !(message.int64Data[i] && $util.isInteger(message.int64Data[i].low) && $util.isInteger(message.int64Data[i].high))) + return "int64Data: integer|Long[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.rawData != null && message.hasOwnProperty("rawData")) + if (!(message.rawData && typeof message.rawData.length === "number" || $util.isString(message.rawData))) + return "rawData: buffer expected"; + if (message.externalData != null && message.hasOwnProperty("externalData")) { + if 
(!Array.isArray(message.externalData)) + return "externalData: array expected"; + for (var i = 0; i < message.externalData.length; ++i) { + var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); + if (error) + return "externalData." + error; + } + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + switch (message.dataLocation) { + default: + return "dataLocation: enum value expected"; + case 0: /* DEFAULT */ + case 1: /* EXTERNAL */ + break; + } + if (message.doubleData != null && message.hasOwnProperty("doubleData")) { + if (!Array.isArray(message.doubleData)) + return "doubleData: array expected"; + for (var i = 0; i < message.doubleData.length; ++i) + if (typeof message.doubleData[i] !== "number") + return "doubleData: number[] expected"; + } + if (message.uint64Data != null && message.hasOwnProperty("uint64Data")) { + if (!Array.isArray(message.uint64Data)) + return "uint64Data: array expected"; + for (var i = 0; i < message.uint64Data.length; ++i) + if (!$util.isInteger(message.uint64Data[i]) && !(message.uint64Data[i] && $util.isInteger(message.uint64Data[i].low) && $util.isInteger(message.uint64Data[i].high))) + return "uint64Data: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. + * An object that is already a TensorProto instance is returned as-is. 
+ * @function fromObject + * @memberof onnx.TensorProto + * @static + * @param {Object.<string,*>} object Plain object + * @returns {onnx.TensorProto} TensorProto + * @throws {TypeError} If a repeated field is not an array or a message field is not an object + */ + TensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto) + return object; + var message = new $root.onnx.TensorProto(); + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.TensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + if (object.dataType != null) + message.dataType = object.dataType | 0; + if (object.segment != null) { + if (typeof object.segment !== "object") + throw TypeError(".onnx.TensorProto.segment: object expected"); + message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); + } + if (object.floatData) { + if (!Array.isArray(object.floatData)) + throw TypeError(".onnx.TensorProto.floatData: array expected"); + message.floatData = []; + for (var i = 0; i < object.floatData.length; ++i) + message.floatData[i] = Number(object.floatData[i]); + } + if (object.int32Data) { + if (!Array.isArray(object.int32Data)) + throw TypeError(".onnx.TensorProto.int32Data: array expected"); + message.int32Data = []; + for (var i = 0; i < object.int32Data.length; ++i) + message.int32Data[i] = object.int32Data[i] | 0; + } + if (object.stringData) { + if (!Array.isArray(object.stringData)) + throw TypeError(".onnx.TensorProto.stringData: array expected"); + message.stringData = []; + for (var i = 0; i < object.stringData.length; ++i) + if (typeof 
object.stringData[i] === "string") + $util.base64.decode(object.stringData[i], message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i])), 0); + else if (object.stringData[i].length >= 0) + message.stringData[i] = object.stringData[i]; + } + if (object.int64Data) { + if (!Array.isArray(object.int64Data)) + throw TypeError(".onnx.TensorProto.int64Data: array expected"); + message.int64Data = []; + for (var i = 0; i < object.int64Data.length; ++i) + if ($util.Long) + (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; + else if (typeof object.int64Data[i] === "string") + message.int64Data[i] = parseInt(object.int64Data[i], 10); + else if (typeof object.int64Data[i] === "number") + message.int64Data[i] = object.int64Data[i]; + else if (typeof object.int64Data[i] === "object") + message.int64Data[i] = new $util.LongBits(object.int64Data[i].low >>> 0, object.int64Data[i].high >>> 0).toNumber(); + } + if (object.name != null) + message.name = String(object.name); + if (object.docString != null) + message.docString = String(object.docString); + if (object.rawData != null) + if (typeof object.rawData === "string") + $util.base64.decode(object.rawData, message.rawData = $util.newBuffer($util.base64.length(object.rawData)), 0); + else if (object.rawData.length >= 0) + message.rawData = object.rawData; + if (object.externalData) { + if (!Array.isArray(object.externalData)) + throw TypeError(".onnx.TensorProto.externalData: array expected"); + message.externalData = []; + for (var i = 0; i < object.externalData.length; ++i) { + if (typeof object.externalData[i] !== "object") + throw TypeError(".onnx.TensorProto.externalData: object expected"); + message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); + } + } + switch (object.dataLocation) { + default: /* unrecognized value: kept only when it is numeric */ + if (typeof object.dataLocation === "number") { + message.dataLocation = object.dataLocation; + break; + } + break; + case 
"DEFAULT": + case 0: + message.dataLocation = 0; + break; + case "EXTERNAL": + case 1: + message.dataLocation = 1; + break; + } + if (object.doubleData) { + if (!Array.isArray(object.doubleData)) + throw TypeError(".onnx.TensorProto.doubleData: array expected"); + message.doubleData = []; + for (var i = 0; i < object.doubleData.length; ++i) + message.doubleData[i] = Number(object.doubleData[i]); + } + if (object.uint64Data) { + if (!Array.isArray(object.uint64Data)) + throw TypeError(".onnx.TensorProto.uint64Data: array expected"); + message.uint64Data = []; + for (var i = 0; i < object.uint64Data.length; ++i) + if ($util.Long) + (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; + else if (typeof object.uint64Data[i] === "string") + message.uint64Data[i] = parseInt(object.uint64Data[i], 10); + else if (typeof object.uint64Data[i] === "number") + message.uint64Data[i] = object.uint64Data[i]; + else if (typeof object.uint64Data[i] === "object") + message.uint64Data[i] = new $util.LongBits(object.uint64Data[i].low >>> 0, object.uint64Data[i].high >>> 0).toNumber(true); + } + return message; + }; + + /** + * Creates a plain object from a TensorProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TensorProto + * @static + * @param {onnx.TensorProto} message TensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.<string,*>} Plain object + */ + TensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.dims = []; + object.floatData = []; + object.int32Data = []; + object.stringData = []; + object.int64Data = []; + object.doubleData = []; + object.uint64Data = []; + object.externalData = []; + } + if (options.defaults) { + object.dataType = 0; + object.segment = null; + object.name = ""; + if (options.bytes === String) + object.rawData = ""; + else { + object.rawData = []; + if (options.bytes !== Array) + object.rawData = $util.newBuffer(object.rawData); + } + object.docString = ""; + object.dataLocation = options.enums === String ? "DEFAULT" : 0; + } + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + if (message.dataType != null && message.hasOwnProperty("dataType")) + object.dataType = message.dataType; + if (message.segment != null && message.hasOwnProperty("segment")) + object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); + if (message.floatData && message.floatData.length) { + object.floatData = []; + for (var j = 0; j < message.floatData.length; ++j) + object.floatData[j] = options.json && !isFinite(message.floatData[j]) ? 
String(message.floatData[j]) : message.floatData[j]; + } + if (message.int32Data && message.int32Data.length) { + object.int32Data = []; + for (var j = 0; j < message.int32Data.length; ++j) + object.int32Data[j] = message.int32Data[j]; + } + if (message.stringData && message.stringData.length) { + object.stringData = []; + for (var j = 0; j < message.stringData.length; ++j) /* options.bytes selects a base64 string, a plain Array copy, or the stored value itself */ + object.stringData[j] = options.bytes === String ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) : options.bytes === Array ? Array.prototype.slice.call(message.stringData[j]) : message.stringData[j]; + } + if (message.int64Data && message.int64Data.length) { + object.int64Data = []; + for (var j = 0; j < message.int64Data.length; ++j) + if (typeof message.int64Data[j] === "number") + object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; + else + object.int64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.int64Data[j]) : options.longs === Number ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() : message.int64Data[j]; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.rawData != null && message.hasOwnProperty("rawData")) + object.rawData = options.bytes === String ? $util.base64.encode(message.rawData, 0, message.rawData.length) : options.bytes === Array ? Array.prototype.slice.call(message.rawData) : message.rawData; + if (message.doubleData && message.doubleData.length) { + object.doubleData = []; + for (var j = 0; j < message.doubleData.length; ++j) + object.doubleData[j] = options.json && !isFinite(message.doubleData[j]) ? 
String(message.doubleData[j]) : message.doubleData[j]; + } + if (message.uint64Data && message.uint64Data.length) { + object.uint64Data = []; + for (var j = 0; j < message.uint64Data.length; ++j) + if (typeof message.uint64Data[j] === "number") + object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; + else + object.uint64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.uint64Data[j]) : options.longs === Number ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) : message.uint64Data[j]; + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.externalData && message.externalData.length) { + object.externalData = []; + for (var j = 0; j < message.externalData.length; ++j) + object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); + } + if (message.dataLocation != null && message.hasOwnProperty("dataLocation")) + object.dataLocation = options.enums === String ? $root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined ? message.dataLocation : $root.onnx.TensorProto.DataLocation[message.dataLocation] : message.dataLocation; + return object; + }; + + /** + * Converts this TensorProto to JSON. 
+ * @function toJSON + * @memberof onnx.TensorProto + * @instance + * @returns {Object.<string,*>} JSON object + */ + TensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorProto + * @function getTypeUrl + * @memberof onnx.TensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto"; + }; + + /** + * DataType enum. + * The resulting object maps each name to its id as an own property; ids map back to names through its prototype object. + * @name onnx.TensorProto.DataType + * @enum {number} + * @property {number} UNDEFINED=0 UNDEFINED value + * @property {number} FLOAT=1 FLOAT value + * @property {number} UINT8=2 UINT8 value + * @property {number} INT8=3 INT8 value + * @property {number} UINT16=4 UINT16 value + * @property {number} INT16=5 INT16 value + * @property {number} INT32=6 INT32 value + * @property {number} INT64=7 INT64 value + * @property {number} STRING=8 STRING value + * @property {number} BOOL=9 BOOL value + * @property {number} FLOAT16=10 FLOAT16 value + * @property {number} DOUBLE=11 DOUBLE value + * @property {number} UINT32=12 UINT32 value + * @property {number} UINT64=13 UINT64 value + * @property {number} COMPLEX64=14 COMPLEX64 value + * @property {number} COMPLEX128=15 COMPLEX128 value + * @property {number} BFLOAT16=16 BFLOAT16 value + * @property {number} FLOAT8E4M3FN=17 FLOAT8E4M3FN value + * @property {number} FLOAT8E4M3FNUZ=18 FLOAT8E4M3FNUZ value + * @property {number} FLOAT8E5M2=19 FLOAT8E5M2 value + * @property {number} FLOAT8E5M2FNUZ=20 FLOAT8E5M2FNUZ value + */ + TensorProto.DataType = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "UNDEFINED"] = 0; + values[valuesById[1] = "FLOAT"] = 1; + 
values[valuesById[2] = "UINT8"] = 2; + values[valuesById[3] = "INT8"] = 3; + values[valuesById[4] = "UINT16"] = 4; + values[valuesById[5] = "INT16"] = 5; + values[valuesById[6] = "INT32"] = 6; + values[valuesById[7] = "INT64"] = 7; + values[valuesById[8] = "STRING"] = 8; + values[valuesById[9] = "BOOL"] = 9; + values[valuesById[10] = "FLOAT16"] = 10; + values[valuesById[11] = "DOUBLE"] = 11; + values[valuesById[12] = "UINT32"] = 12; + values[valuesById[13] = "UINT64"] = 13; + values[valuesById[14] = "COMPLEX64"] = 14; + values[valuesById[15] = "COMPLEX128"] = 15; + values[valuesById[16] = "BFLOAT16"] = 16; + values[valuesById[17] = "FLOAT8E4M3FN"] = 17; + values[valuesById[18] = "FLOAT8E4M3FNUZ"] = 18; + values[valuesById[19] = "FLOAT8E5M2"] = 19; + values[valuesById[20] = "FLOAT8E5M2FNUZ"] = 20; + return values; + })(); + + TensorProto.Segment = (function() { + + /** + * Properties of a Segment. + * @memberof onnx.TensorProto + * @interface ISegment + * @property {number|Long|null} [begin] Segment begin + * @property {number|Long|null} [end] Segment end + */ + + /** + * Constructs a new Segment. + * Every property of `properties` whose value is not null or undefined is copied onto the new instance. + * @memberof onnx.TensorProto + * @classdesc Represents a Segment. + * @implements ISegment + * @constructor + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + */ + function Segment(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Segment begin. + * @member {number|Long} begin + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Segment end. + * @member {number|Long} end + * @memberof onnx.TensorProto.Segment + * @instance + */ + Segment.prototype.end = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new Segment instance using the specified properties. 
+ * @function create + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment=} [properties] Properties to set + * @returns {onnx.TensorProto.Segment} Segment instance + */ + Segment.create = function create(properties) { + return new Segment(properties); + }; + + /** + * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * A field is written only when it is a non-null own property of the message; a new writer is created when none is given. + * @function encode + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.begin != null && Object.hasOwnProperty.call(message, "begin")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.begin); + if (message.end != null && Object.hasOwnProperty.call(message, "end")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.end); + return writer; + }; + + /** + * Encodes the specified Segment message, length delimited. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.ISegment} message Segment message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Segment.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Segment message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorProto.Segment(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.begin = reader.int64(); + break; + } + case 2: { + message.end = reader.int64(); + break; + } + default: /* any other field id: skipped according to its wire type */ + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Segment message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorProto.Segment + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorProto.Segment} Segment + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Segment.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Segment message. 
+ * @function verify + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.<string,*>} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Segment.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.begin != null && message.hasOwnProperty("begin")) + if (!$util.isInteger(message.begin) && !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high))) + return "begin: integer|Long expected"; + if (message.end != null && message.hasOwnProperty("end")) + if (!$util.isInteger(message.end) && !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high))) + return "end: integer|Long expected"; + return null; + }; + + /** + * Creates a Segment message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {Object.<string,*>} object Plain object + * @returns {onnx.TensorProto.Segment} Segment + */ + Segment.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorProto.Segment) + return object; + var message = new $root.onnx.TensorProto.Segment(); + if (object.begin != null) + if ($util.Long) + (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; + else if (typeof object.begin === "string") + message.begin = parseInt(object.begin, 10); + else if (typeof object.begin === "number") + message.begin = object.begin; + else if (typeof object.begin === "object") + message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); + if (object.end != null) + if ($util.Long) + (message.end = $util.Long.fromValue(object.end)).unsigned = false; + else if (typeof object.end === "string") + message.end = parseInt(object.end, 10); + else if (typeof object.end === "number") + message.end = object.end; + else if 
(typeof object.end === "object") + message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from a Segment message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorProto.Segment + * @static + * @param {onnx.TensorProto.Segment} message Segment + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.<string,*>} Plain object + */ + Segment.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.begin = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.begin = options.longs === String ? "0" : 0; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.end = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.end = options.longs === String ? "0" : 0; + } + if (message.begin != null && message.hasOwnProperty("begin")) + if (typeof message.begin === "number") + object.begin = options.longs === String ? String(message.begin) : message.begin; + else + object.begin = options.longs === String ? $util.Long.prototype.toString.call(message.begin) : options.longs === Number ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() : message.begin; + if (message.end != null && message.hasOwnProperty("end")) + if (typeof message.end === "number") + object.end = options.longs === String ? String(message.end) : message.end; + else + object.end = options.longs === String ? $util.Long.prototype.toString.call(message.end) : options.longs === Number ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() : message.end; + return object; + }; + + /** + * Converts this Segment to JSON. 
+ * @function toJSON + * @memberof onnx.TensorProto.Segment + * @instance + * @returns {Object.<string,*>} JSON object + */ + Segment.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Segment + * @function getTypeUrl + * @memberof onnx.TensorProto.Segment + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorProto.Segment"; + }; + + return Segment; + })(); + + /** + * DataLocation enum. + * @name onnx.TensorProto.DataLocation + * @enum {number} + * @property {number} DEFAULT=0 DEFAULT value + * @property {number} EXTERNAL=1 EXTERNAL value + */ + TensorProto.DataLocation = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "DEFAULT"] = 0; + values[valuesById[1] = "EXTERNAL"] = 1; + return values; + })(); + + return TensorProto; + })(); + + onnx.SparseTensorProto = (function() { + + /** + * Properties of a SparseTensorProto. + * @memberof onnx + * @interface ISparseTensorProto + * @property {onnx.ITensorProto|null} [values] SparseTensorProto values + * @property {onnx.ITensorProto|null} [indices] SparseTensorProto indices + * @property {Array.<number|Long>|null} [dims] SparseTensorProto dims + */ + + /** + * Constructs a new SparseTensorProto. + * @memberof onnx + * @classdesc Represents a SparseTensorProto. 
+ * @implements ISparseTensorProto + * @constructor + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + */ + function SparseTensorProto(properties) { + this.dims = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensorProto values. + * @member {onnx.ITensorProto|null|undefined} values + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.values = null; + + /** + * SparseTensorProto indices. + * @member {onnx.ITensorProto|null|undefined} indices + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.indices = null; + + /** + * SparseTensorProto dims. + * @member {Array.<number|Long>} dims + * @memberof onnx.SparseTensorProto + * @instance + */ + SparseTensorProto.prototype.dims = $util.emptyArray; + + /** + * Creates a new SparseTensorProto instance using the specified properties. + * @function create + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto=} [properties] Properties to set + * @returns {onnx.SparseTensorProto} SparseTensorProto instance + */ + SparseTensorProto.create = function create(properties) { + return new SparseTensorProto(properties); + }; + + /** + * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.values != null && Object.hasOwnProperty.call(message, "values")) + $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.indices != null && Object.hasOwnProperty.call(message, "indices")) + $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + if (message.dims != null && message.dims.length) { + writer.uint32(/* id 3, wireType 2 =*/26).fork(); /* all dims go into this single length-delimited field; nothing is written when dims is empty */ + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); + writer.ldelim(); + } + return writer; + }; + + /** + * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.ISparseTensorProto} message SparseTensorProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.SparseTensorProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 2: { + message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); + break; + } + case 3: { + if (!(message.dims && message.dims.length)) + message.dims = []; + if ((tag & 7) === 2) { + var end2 = reader.uint32() + reader.pos; + while (reader.pos < end2) + message.dims.push(reader.int64()); + } else + message.dims.push(reader.int64()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.SparseTensorProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.SparseTensorProto} SparseTensorProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensorProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensorProto message. + * @function verify + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensorProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.values != null && message.hasOwnProperty("values")) { + var error = $root.onnx.TensorProto.verify(message.values); + if (error) + return "values." + error; + } + if (message.indices != null && message.hasOwnProperty("indices")) { + var error = $root.onnx.TensorProto.verify(message.indices); + if (error) + return "indices." + error; + } + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) + return "dims: array expected"; + for (var i = 0; i < message.dims.length; ++i) + if (!$util.isInteger(message.dims[i]) && !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high))) + return "dims: integer|Long[] expected"; + } + return null; + }; + + /** + * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.SparseTensorProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.SparseTensorProto} SparseTensorProto + */ + SparseTensorProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.SparseTensorProto) + return object; + var message = new $root.onnx.SparseTensorProto(); + if (object.values != null) { + if (typeof object.values !== "object") + throw TypeError(".onnx.SparseTensorProto.values: object expected"); + message.values = $root.onnx.TensorProto.fromObject(object.values); + } + if (object.indices != null) { + if (typeof object.indices !== "object") + throw TypeError(".onnx.SparseTensorProto.indices: object expected"); + message.indices = $root.onnx.TensorProto.fromObject(object.indices); + } + if (object.dims) { + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.SparseTensorProto.dims: array expected"); + message.dims = []; + for (var i = 0; i < object.dims.length; ++i) + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.SparseTensorProto + * @static + * @param {onnx.SparseTensorProto} message SparseTensorProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensorProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dims = []; + if (options.defaults) { + object.values = null; + object.indices = null; + } + if (message.values != null && message.hasOwnProperty("values")) + object.values = $root.onnx.TensorProto.toObject(message.values, options); + if (message.indices != null && message.hasOwnProperty("indices")) + object.indices = $root.onnx.TensorProto.toObject(message.indices, options); + if (message.dims && message.dims.length) { + object.dims = []; + for (var j = 0; j < message.dims.length; ++j) + if (typeof message.dims[j] === "number") + object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + else + object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() : message.dims[j]; + } + return object; + }; + + /** + * Converts this SparseTensorProto to JSON. 
+ * @function toJSON + * @memberof onnx.SparseTensorProto + * @instance + * @returns {Object.} JSON object + */ + SparseTensorProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensorProto + * @function getTypeUrl + * @memberof onnx.SparseTensorProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.SparseTensorProto"; + }; + + return SparseTensorProto; + })(); + + onnx.TensorShapeProto = (function() { + + /** + * Properties of a TensorShapeProto. + * @memberof onnx + * @interface ITensorShapeProto + * @property {Array.|null} [dim] TensorShapeProto dim + */ + + /** + * Constructs a new TensorShapeProto. + * @memberof onnx + * @classdesc Represents a TensorShapeProto. + * @implements ITensorShapeProto + * @constructor + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + */ + function TensorShapeProto(properties) { + this.dim = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TensorShapeProto dim. + * @member {Array.} dim + * @memberof onnx.TensorShapeProto + * @instance + */ + TensorShapeProto.prototype.dim = $util.emptyArray; + + /** + * Creates a new TensorShapeProto instance using the specified properties. 
+ * @function create + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto=} [properties] Properties to set + * @returns {onnx.TensorShapeProto} TensorShapeProto instance + */ + TensorShapeProto.create = function create(properties) { + return new TensorShapeProto(properties); + }; + + /** + * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dim != null && message.dim.length) + for (var i = 0; i < message.dim.length; ++i) + $root.onnx.TensorShapeProto.Dimension.encode(message.dim[i], writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.ITensorShapeProto} message TensorShapeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (!(message.dim && message.dim.length)) + message.dim = []; + message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto} TensorShapeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TensorShapeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TensorShapeProto message. 
+ * @function verify + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TensorShapeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dim != null && message.hasOwnProperty("dim")) { + if (!Array.isArray(message.dim)) + return "dim: array expected"; + for (var i = 0; i < message.dim.length; ++i) { + var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); + if (error) + return "dim." + error; + } + } + return null; + }; + + /** + * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TensorShapeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto} TensorShapeProto + */ + TensorShapeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto) + return object; + var message = new $root.onnx.TensorShapeProto(); + if (object.dim) { + if (!Array.isArray(object.dim)) + throw TypeError(".onnx.TensorShapeProto.dim: array expected"); + message.dim = []; + for (var i = 0; i < object.dim.length; ++i) { + if (typeof object.dim[i] !== "object") + throw TypeError(".onnx.TensorShapeProto.dim: object expected"); + message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); + } + } + return message; + }; + + /** + * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TensorShapeProto + * @static + * @param {onnx.TensorShapeProto} message TensorShapeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TensorShapeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) + object.dim = []; + if (message.dim && message.dim.length) { + object.dim = []; + for (var j = 0; j < message.dim.length; ++j) + object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); + } + return object; + }; + + /** + * Converts this TensorShapeProto to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto + * @instance + * @returns {Object.} JSON object + */ + TensorShapeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TensorShapeProto + * @function getTypeUrl + * @memberof onnx.TensorShapeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto"; + }; + + TensorShapeProto.Dimension = (function() { + + /** + * Properties of a Dimension. + * @memberof onnx.TensorShapeProto + * @interface IDimension + * @property {number|Long|null} [dimValue] Dimension dimValue + * @property {string|null} [dimParam] Dimension dimParam + * @property {string|null} [denotation] Dimension denotation + */ + + /** + * Constructs a new Dimension. + * @memberof onnx.TensorShapeProto + * @classdesc Represents a Dimension. 
+ * @implements IDimension + * @constructor + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + */ + function Dimension(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Dimension dimValue. + * @member {number|Long|null|undefined} dimValue + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimValue = null; + + /** + * Dimension dimParam. + * @member {string|null|undefined} dimParam + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.dimParam = null; + + /** + * Dimension denotation. + * @member {string} denotation + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Dimension.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * Dimension value. + * @member {"dimValue"|"dimParam"|undefined} value + * @memberof onnx.TensorShapeProto.Dimension + * @instance + */ + Object.defineProperty(Dimension.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["dimValue", "dimParam"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new Dimension instance using the specified properties. + * @function create + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension=} [properties] Properties to set + * @returns {onnx.TensorShapeProto.Dimension} Dimension instance + */ + Dimension.create = function create(properties) { + return new Dimension(properties); + }; + + /** + * Encodes the specified Dimension message. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. 
+ * @function encode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.dimValue != null && Object.hasOwnProperty.call(message, "dimValue")) + writer.uint32(/* id 1, wireType 0 =*/8).int64(message.dimValue); + if (message.dimParam != null && Object.hasOwnProperty.call(message, "dimParam")) + writer.uint32(/* id 2, wireType 2 =*/18).string(message.dimParam); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 3, wireType 2 =*/26).string(message.denotation); + return writer; + }; + + /** + * Encodes the specified Dimension message, length delimited. Does not implicitly {@link onnx.TensorShapeProto.Dimension.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.IDimension} message Dimension message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Dimension.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Dimension message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TensorShapeProto.Dimension(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.dimValue = reader.int64(); + break; + } + case 2: { + message.dimParam = reader.string(); + break; + } + case 3: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Dimension message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TensorShapeProto.Dimension} Dimension + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Dimension.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Dimension message. 
+ * @function verify + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Dimension.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + properties.value = 1; + if (!$util.isInteger(message.dimValue) && !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high))) + return "dimValue: integer|Long expected"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + if (!$util.isString(message.dimParam)) + return "dimParam: string expected"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a Dimension message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {Object.} object Plain object + * @returns {onnx.TensorShapeProto.Dimension} Dimension + */ + Dimension.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TensorShapeProto.Dimension) + return object; + var message = new $root.onnx.TensorShapeProto.Dimension(); + if (object.dimValue != null) + if ($util.Long) + (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; + else if (typeof object.dimValue === "string") + message.dimValue = parseInt(object.dimValue, 10); + else if (typeof object.dimValue === "number") + message.dimValue = object.dimValue; + else if (typeof object.dimValue === "object") + message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); + if (object.dimParam != null) + message.dimParam = String(object.dimParam); + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a Dimension message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {onnx.TensorShapeProto.Dimension} message Dimension + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Dimension.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + if (typeof message.dimValue === "number") + object.dimValue = options.longs === String ? String(message.dimValue) : message.dimValue; + else + object.dimValue = options.longs === String ? $util.Long.prototype.toString.call(message.dimValue) : options.longs === Number ? 
new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() : message.dimValue; + if (options.oneofs) + object.value = "dimValue"; + } + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + object.dimParam = message.dimParam; + if (options.oneofs) + object.value = "dimParam"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + return object; + }; + + /** + * Converts this Dimension to JSON. + * @function toJSON + * @memberof onnx.TensorShapeProto.Dimension + * @instance + * @returns {Object.} JSON object + */ + Dimension.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Dimension + * @function getTypeUrl + * @memberof onnx.TensorShapeProto.Dimension + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TensorShapeProto.Dimension"; + }; + + return Dimension; + })(); + + return TensorShapeProto; + })(); + + onnx.TypeProto = (function() { + + /** + * Properties of a TypeProto. + * @memberof onnx + * @interface ITypeProto + * @property {onnx.TypeProto.ITensor|null} [tensorType] TypeProto tensorType + * @property {onnx.TypeProto.ISequence|null} [sequenceType] TypeProto sequenceType + * @property {onnx.TypeProto.IMap|null} [mapType] TypeProto mapType + * @property {onnx.TypeProto.IOptional|null} [optionalType] TypeProto optionalType + * @property {onnx.TypeProto.ISparseTensor|null} [sparseTensorType] TypeProto sparseTensorType + * @property {string|null} [denotation] TypeProto denotation + */ + + /** + * Constructs a new TypeProto. 
+ * @memberof onnx + * @classdesc Represents a TypeProto. + * @implements ITypeProto + * @constructor + * @param {onnx.ITypeProto=} [properties] Properties to set + */ + function TypeProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * TypeProto tensorType. + * @member {onnx.TypeProto.ITensor|null|undefined} tensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.tensorType = null; + + /** + * TypeProto sequenceType. + * @member {onnx.TypeProto.ISequence|null|undefined} sequenceType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sequenceType = null; + + /** + * TypeProto mapType. + * @member {onnx.TypeProto.IMap|null|undefined} mapType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.mapType = null; + + /** + * TypeProto optionalType. + * @member {onnx.TypeProto.IOptional|null|undefined} optionalType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.optionalType = null; + + /** + * TypeProto sparseTensorType. + * @member {onnx.TypeProto.ISparseTensor|null|undefined} sparseTensorType + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.sparseTensorType = null; + + /** + * TypeProto denotation. + * @member {string} denotation + * @memberof onnx.TypeProto + * @instance + */ + TypeProto.prototype.denotation = ""; + + // OneOf field names bound to virtual getters and setters + var $oneOfFields; + + /** + * TypeProto value. 
+ * @member {"tensorType"|"sequenceType"|"mapType"|"optionalType"|"sparseTensorType"|undefined} value + * @memberof onnx.TypeProto + * @instance + */ + Object.defineProperty(TypeProto.prototype, "value", { + get: $util.oneOfGetter($oneOfFields = ["tensorType", "sequenceType", "mapType", "optionalType", "sparseTensorType"]), + set: $util.oneOfSetter($oneOfFields) + }); + + /** + * Creates a new TypeProto instance using the specified properties. + * @function create + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto=} [properties] Properties to set + * @returns {onnx.TypeProto} TypeProto instance + */ + TypeProto.create = function create(properties) { + return new TypeProto(properties); + }; + + /** + * Encodes the specified TypeProto message. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.tensorType != null && Object.hasOwnProperty.call(message, "tensorType")) + $root.onnx.TypeProto.Tensor.encode(message.tensorType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + if (message.sequenceType != null && Object.hasOwnProperty.call(message, "sequenceType")) + $root.onnx.TypeProto.Sequence.encode(message.sequenceType, writer.uint32(/* id 4, wireType 2 =*/34).fork()).ldelim(); + if (message.mapType != null && Object.hasOwnProperty.call(message, "mapType")) + $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/42).fork()).ldelim(); + if (message.denotation != null && Object.hasOwnProperty.call(message, "denotation")) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.denotation); + if (message.sparseTensorType != null && 
Object.hasOwnProperty.call(message, "sparseTensorType")) + $root.onnx.TypeProto.SparseTensor.encode(message.sparseTensorType, writer.uint32(/* id 8, wireType 2 =*/66).fork()).ldelim(); + if (message.optionalType != null && Object.hasOwnProperty.call(message, "optionalType")) + $root.onnx.TypeProto.Optional.encode(message.optionalType, writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link onnx.TypeProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {onnx.ITypeProto} message TypeProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + TypeProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TypeProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); + break; + } + case 4: { + message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); + break; + } + case 5: { + message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); + break; + } + case 9: { + message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); + break; + } + case 8: { + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); + break; + } + case 6: { + message.denotation = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a TypeProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto} TypeProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + TypeProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a TypeProto message. 
+ * @function verify + * @memberof onnx.TypeProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + TypeProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + var properties = {}; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + properties.value = 1; + { + var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); + if (error) + return "tensorType." + error; + } + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); + if (error) + return "sequenceType." + error; + } + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Map.verify(message.mapType); + if (error) + return "mapType." + error; + } + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); + if (error) + return "optionalType." + error; + } + } + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + if (properties.value === 1) + return "value: multiple values"; + properties.value = 1; + { + var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType); + if (error) + return "sparseTensorType." 
+ error; + } + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; + return null; + }; + + /** + * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto} TypeProto + */ + TypeProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto) + return object; + var message = new $root.onnx.TypeProto(); + if (object.tensorType != null) { + if (typeof object.tensorType !== "object") + throw TypeError(".onnx.TypeProto.tensorType: object expected"); + message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType); + } + if (object.sequenceType != null) { + if (typeof object.sequenceType !== "object") + throw TypeError(".onnx.TypeProto.sequenceType: object expected"); + message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType); + } + if (object.mapType != null) { + if (typeof object.mapType !== "object") + throw TypeError(".onnx.TypeProto.mapType: object expected"); + message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType); + } + if (object.optionalType != null) { + if (typeof object.optionalType !== "object") + throw TypeError(".onnx.TypeProto.optionalType: object expected"); + message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType); + } + if (object.sparseTensorType != null) { + if (typeof object.sparseTensorType !== "object") + throw TypeError(".onnx.TypeProto.sparseTensorType: object expected"); + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType); + } + if (object.denotation != null) + message.denotation = String(object.denotation); + return message; + }; + + /** + * Creates a plain object from a TypeProto message. 
Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto + * @static + * @param {onnx.TypeProto} message TypeProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + TypeProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.denotation = ""; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); + if (options.oneofs) + object.value = "tensorType"; + } + if (message.sequenceType != null && message.hasOwnProperty("sequenceType")) { + object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); + if (options.oneofs) + object.value = "sequenceType"; + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); + if (options.oneofs) + object.value = "mapType"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + if (message.sparseTensorType != null && message.hasOwnProperty("sparseTensorType")) { + object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); + if (options.oneofs) + object.value = "sparseTensorType"; + } + if (message.optionalType != null && message.hasOwnProperty("optionalType")) { + object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); + if (options.oneofs) + object.value = "optionalType"; + } + return object; + }; + + /** + * Converts this TypeProto to JSON. 
+ * @function toJSON + * @memberof onnx.TypeProto + * @instance + * @returns {Object.} JSON object + */ + TypeProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for TypeProto + * @function getTypeUrl + * @memberof onnx.TypeProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto"; + }; + + TypeProto.Tensor = (function() { + + /** + * Properties of a Tensor. + * @memberof onnx.TypeProto + * @interface ITensor + * @property {number|null} [elemType] Tensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] Tensor shape + */ + + /** + * Constructs a new Tensor. + * @memberof onnx.TypeProto + * @classdesc Represents a Tensor. + * @implements ITensor + * @constructor + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + */ + function Tensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Tensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.elemType = 0; + + /** + * Tensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.Tensor + * @instance + */ + Tensor.prototype.shape = null; + + /** + * Creates a new Tensor instance using the specified properties. 
+ * @function create + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor=} [properties] Properties to set + * @returns {onnx.TypeProto.Tensor} Tensor instance + */ + Tensor.create = function create(properties) { + return new Tensor(properties); + }; + + /** + * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Tensor message, length delimited. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.ITensor} message Tensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Tensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Tensor message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Tensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Tensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Tensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Tensor} Tensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Tensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Tensor message. 
+ * @function verify + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Tensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a Tensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Tensor} Tensor + */ + Tensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Tensor) + return object; + var message = new $root.onnx.TypeProto.Tensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.Tensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a Tensor message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto.Tensor + * @static + * @param {onnx.TypeProto.Tensor} message Tensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Tensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this Tensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Tensor + * @instance + * @returns {Object.} JSON object + */ + Tensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Tensor + * @function getTypeUrl + * @memberof onnx.TypeProto.Tensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Tensor"; + }; + + return Tensor; + })(); + + TypeProto.Sequence = (function() { + + /** + * Properties of a Sequence. + * @memberof onnx.TypeProto + * @interface ISequence + * @property {onnx.ITypeProto|null} [elemType] Sequence elemType + */ + + /** + * Constructs a new Sequence. + * @memberof onnx.TypeProto + * @classdesc Represents a Sequence. 
+ * @implements ISequence + * @constructor + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + */ + function Sequence(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Sequence elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Sequence + * @instance + */ + Sequence.prototype.elemType = null; + + /** + * Creates a new Sequence instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence=} [properties] Properties to set + * @returns {onnx.TypeProto.Sequence} Sequence instance + */ + Sequence.create = function create(properties) { + return new Sequence(properties); + }; + + /** + * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Sequence message, length delimited. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.ISequence} message Sequence message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Sequence.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Sequence message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Sequence(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Sequence message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TypeProto.Sequence + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Sequence} Sequence + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Sequence.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Sequence message. + * @function verify + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Sequence.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates a Sequence message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Sequence} Sequence + */ + Sequence.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Sequence) + return object; + var message = new $root.onnx.TypeProto.Sequence(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Sequence.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from a Sequence message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto.Sequence + * @static + * @param {onnx.TypeProto.Sequence} message Sequence + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Sequence.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Sequence to JSON. + * @function toJSON + * @memberof onnx.TypeProto.Sequence + * @instance + * @returns {Object.} JSON object + */ + Sequence.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Sequence + * @function getTypeUrl + * @memberof onnx.TypeProto.Sequence + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Sequence"; + }; + + return Sequence; + })(); + + TypeProto.Map = (function() { + + /** + * Properties of a Map. + * @memberof onnx.TypeProto + * @interface IMap + * @property {number|null} [keyType] Map keyType + * @property {onnx.ITypeProto|null} [valueType] Map valueType + */ + + /** + * Constructs a new Map. + * @memberof onnx.TypeProto + * @classdesc Represents a Map. 
+ * @implements IMap + * @constructor + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + */ + function Map(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Map keyType. + * @member {number} keyType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.keyType = 0; + + /** + * Map valueType. + * @member {onnx.ITypeProto|null|undefined} valueType + * @memberof onnx.TypeProto.Map + * @instance + */ + Map.prototype.valueType = null; + + /** + * Creates a new Map instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap=} [properties] Properties to set + * @returns {onnx.TypeProto.Map} Map instance + */ + Map.create = function create(properties) { + return new Map(properties); + }; + + /** + * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.keyType != null && Object.hasOwnProperty.call(message, "keyType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.keyType); + if (message.valueType != null && Object.hasOwnProperty.call(message, "valueType")) + $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Map message, length delimited. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. 
+ * @function encodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.IMap} message Map message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Map.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a Map message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Map(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.keyType = reader.int32(); + break; + } + case 2: { + message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a Map message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.TypeProto.Map + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Map} Map + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Map.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a Map message. + * @function verify + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Map.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.keyType != null && message.hasOwnProperty("keyType")) + if (!$util.isInteger(message.keyType)) + return "keyType: integer expected"; + if (message.valueType != null && message.hasOwnProperty("valueType")) { + var error = $root.onnx.TypeProto.verify(message.valueType); + if (error) + return "valueType." + error; + } + return null; + }; + + /** + * Creates a Map message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TypeProto.Map + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Map} Map + */ + Map.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Map) + return object; + var message = new $root.onnx.TypeProto.Map(); + if (object.keyType != null) + message.keyType = object.keyType | 0; + if (object.valueType != null) { + if (typeof object.valueType !== "object") + throw TypeError(".onnx.TypeProto.Map.valueType: object expected"); + message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); + } + return message; + }; + + /** + * Creates a plain object from a Map message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Map + * @static + * @param {onnx.TypeProto.Map} message Map + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Map.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.keyType = 0; + object.valueType = null; + } + if (message.keyType != null && message.hasOwnProperty("keyType")) + object.keyType = message.keyType; + if (message.valueType != null && message.hasOwnProperty("valueType")) + object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); + return object; + }; + + /** + * Converts this Map to JSON. 
+ * @function toJSON + * @memberof onnx.TypeProto.Map + * @instance + * @returns {Object.} JSON object + */ + Map.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Map + * @function getTypeUrl + * @memberof onnx.TypeProto.Map + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Map"; + }; + + return Map; + })(); + + TypeProto.Optional = (function() { + + /** + * Properties of an Optional. + * @memberof onnx.TypeProto + * @interface IOptional + * @property {onnx.ITypeProto|null} [elemType] Optional elemType + */ + + /** + * Constructs a new Optional. + * @memberof onnx.TypeProto + * @classdesc Represents an Optional. + * @implements IOptional + * @constructor + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + */ + function Optional(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * Optional elemType. + * @member {onnx.ITypeProto|null|undefined} elemType + * @memberof onnx.TypeProto.Optional + * @instance + */ + Optional.prototype.elemType = null; + + /** + * Creates a new Optional instance using the specified properties. + * @function create + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional=} [properties] Properties to set + * @returns {onnx.TypeProto.Optional} Optional instance + */ + Optional.create = function create(properties) { + return new Optional(properties); + }; + + /** + * Encodes the specified Optional message. 
Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/10).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified Optional message, length delimited. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.IOptional} message Optional message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + Optional.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an Optional message from the specified reader or buffer. + * @function decode + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.TypeProto.Optional(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an Optional message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.Optional + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.Optional} Optional + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + Optional.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an Optional message. + * @function verify + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + Optional.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { + var error = $root.onnx.TypeProto.verify(message.elemType); + if (error) + return "elemType." + error; + } + return null; + }; + + /** + * Creates an Optional message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.Optional} Optional + */ + Optional.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.Optional) + return object; + var message = new $root.onnx.TypeProto.Optional(); + if (object.elemType != null) { + if (typeof object.elemType !== "object") + throw TypeError(".onnx.TypeProto.Optional.elemType: object expected"); + message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); + } + return message; + }; + + /** + * Creates a plain object from an Optional message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.TypeProto.Optional + * @static + * @param {onnx.TypeProto.Optional} message Optional + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + Optional.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) + object.elemType = null; + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + return object; + }; + + /** + * Converts this Optional to JSON. 
+ * @function toJSON + * @memberof onnx.TypeProto.Optional + * @instance + * @returns {Object.} JSON object + */ + Optional.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for Optional + * @function getTypeUrl + * @memberof onnx.TypeProto.Optional + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.Optional"; + }; + + return Optional; + })(); + + TypeProto.SparseTensor = (function() { + + /** + * Properties of a SparseTensor. + * @memberof onnx.TypeProto + * @interface ISparseTensor + * @property {number|null} [elemType] SparseTensor elemType + * @property {onnx.ITensorShapeProto|null} [shape] SparseTensor shape + */ + + /** + * Constructs a new SparseTensor. + * @memberof onnx.TypeProto + * @classdesc Represents a SparseTensor. + * @implements ISparseTensor + * @constructor + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + */ + function SparseTensor(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * SparseTensor elemType. + * @member {number} elemType + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.elemType = 0; + + /** + * SparseTensor shape. + * @member {onnx.ITensorShapeProto|null|undefined} shape + * @memberof onnx.TypeProto.SparseTensor + * @instance + */ + SparseTensor.prototype.shape = null; + + /** + * Creates a new SparseTensor instance using the specified properties. 
+ * @function create + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor=} [properties] Properties to set + * @returns {onnx.TypeProto.SparseTensor} SparseTensor instance + */ + SparseTensor.create = function create(properties) { + return new SparseTensor(properties); + }; + + /** + * Encodes the specified SparseTensor message. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.elemType != null && Object.hasOwnProperty.call(message, "elemType")) + writer.uint32(/* id 1, wireType 0 =*/8).int32(message.elemType); + if (message.shape != null && Object.hasOwnProperty.call(message, "shape")) + $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/18).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link onnx.TypeProto.SparseTensor.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.ISparseTensor} message SparseTensor message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + SparseTensor.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer. 
+ * @function decode + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? reader.len : reader.pos + length, message = new $root.onnx.TypeProto.SparseTensor(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.elemType = reader.int32(); + break; + } + case 2: { + message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a SparseTensor message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + SparseTensor.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a SparseTensor message. 
+ * @function verify + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + SparseTensor.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { + var error = $root.onnx.TensorShapeProto.verify(message.shape); + if (error) + return "shape." + error; + } + return null; + }; + + /** + * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. + * @function fromObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {Object.} object Plain object + * @returns {onnx.TypeProto.SparseTensor} SparseTensor + */ + SparseTensor.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.TypeProto.SparseTensor) + return object; + var message = new $root.onnx.TypeProto.SparseTensor(); + if (object.elemType != null) + message.elemType = object.elemType | 0; + if (object.shape != null) { + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.SparseTensor.shape: object expected"); + message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); + } + return message; + }; + + /** + * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. 
+ * @function toObject + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {onnx.TypeProto.SparseTensor} message SparseTensor + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + SparseTensor.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.elemType = 0; + object.shape = null; + } + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + return object; + }; + + /** + * Converts this SparseTensor to JSON. + * @function toJSON + * @memberof onnx.TypeProto.SparseTensor + * @instance + * @returns {Object.} JSON object + */ + SparseTensor.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for SparseTensor + * @function getTypeUrl + * @memberof onnx.TypeProto.SparseTensor + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.TypeProto.SparseTensor"; + }; + + return SparseTensor; + })(); + + return TypeProto; + })(); + + onnx.OperatorSetIdProto = (function() { + + /** + * Properties of an OperatorSetIdProto. + * @memberof onnx + * @interface IOperatorSetIdProto + * @property {string|null} [domain] OperatorSetIdProto domain + * @property {number|Long|null} [version] OperatorSetIdProto version + */ + + /** + * Constructs a new OperatorSetIdProto. + * @memberof onnx + * @classdesc Represents an OperatorSetIdProto. 
+ * @implements IOperatorSetIdProto + * @constructor + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + */ + function OperatorSetIdProto(properties) { + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * OperatorSetIdProto domain. + * @member {string} domain + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.domain = ""; + + /** + * OperatorSetIdProto version. + * @member {number|Long} version + * @memberof onnx.OperatorSetIdProto + * @instance + */ + OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + + /** + * Creates a new OperatorSetIdProto instance using the specified properties. + * @function create + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto=} [properties] Properties to set + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto instance + */ + OperatorSetIdProto.create = function create(properties) { + return new OperatorSetIdProto(properties); + }; + + /** + * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.domain); + if (message.version != null && Object.hasOwnProperty.call(message, "version")) + writer.uint32(/* id 2, wireType 0 =*/16).int64(message.version); + return writer; + }; + + /** + * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link onnx.OperatorSetIdProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.IOperatorSetIdProto} message OperatorSetIdProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.OperatorSetIdProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.domain = reader.string(); + break; + } + case 2: { + message.version = reader.int64(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. + * @function decodeDelimited + * @memberof onnx.OperatorSetIdProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + OperatorSetIdProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies an OperatorSetIdProto message. + * @function verify + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + OperatorSetIdProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + if (message.version != null && message.hasOwnProperty("version")) + if (!$util.isInteger(message.version) && !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high))) + return "version: integer|Long expected"; + return null; + }; + + /** + * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.OperatorSetIdProto} OperatorSetIdProto + */ + OperatorSetIdProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.OperatorSetIdProto) + return object; + var message = new $root.onnx.OperatorSetIdProto(); + if (object.domain != null) + message.domain = String(object.domain); + if (object.version != null) + if ($util.Long) + (message.version = $util.Long.fromValue(object.version)).unsigned = false; + else if (typeof object.version === "string") + message.version = parseInt(object.version, 10); + else if (typeof object.version === "number") + message.version = object.version; + else if (typeof object.version === "object") + message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); + return message; + }; + + /** + * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.OperatorSetIdProto + * @static + * @param {onnx.OperatorSetIdProto} message OperatorSetIdProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + OperatorSetIdProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.defaults) { + object.domain = ""; + if ($util.Long) { + var long = new $util.Long(0, 0, false); + object.version = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; + } else + object.version = options.longs === String ? "0" : 0; + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.version != null && message.hasOwnProperty("version")) + if (typeof message.version === "number") + object.version = options.longs === String ? 
String(message.version) : message.version; + else + object.version = options.longs === String ? $util.Long.prototype.toString.call(message.version) : options.longs === Number ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() : message.version; + return object; + }; + + /** + * Converts this OperatorSetIdProto to JSON. + * @function toJSON + * @memberof onnx.OperatorSetIdProto + * @instance + * @returns {Object.} JSON object + */ + OperatorSetIdProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for OperatorSetIdProto + * @function getTypeUrl + * @memberof onnx.OperatorSetIdProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.OperatorSetIdProto"; + }; + + return OperatorSetIdProto; + })(); + + /** + * OperatorStatus enum. + * @name onnx.OperatorStatus + * @enum {number} + * @property {number} EXPERIMENTAL=0 EXPERIMENTAL value + * @property {number} STABLE=1 STABLE value + */ + onnx.OperatorStatus = (function() { + var valuesById = {}, values = Object.create(valuesById); + values[valuesById[0] = "EXPERIMENTAL"] = 0; + values[valuesById[1] = "STABLE"] = 1; + return values; + })(); + + onnx.FunctionProto = (function() { + + /** + * Properties of a FunctionProto. 
+ * @memberof onnx + * @interface IFunctionProto + * @property {string|null} [name] FunctionProto name + * @property {Array.|null} [input] FunctionProto input + * @property {Array.|null} [output] FunctionProto output + * @property {Array.|null} [attribute] FunctionProto attribute + * @property {Array.|null} [attributeProto] FunctionProto attributeProto + * @property {Array.|null} [node] FunctionProto node + * @property {string|null} [docString] FunctionProto docString + * @property {Array.|null} [opsetImport] FunctionProto opsetImport + * @property {string|null} [domain] FunctionProto domain + */ + + /** + * Constructs a new FunctionProto. + * @memberof onnx + * @classdesc Represents a FunctionProto. + * @implements IFunctionProto + * @constructor + * @param {onnx.IFunctionProto=} [properties] Properties to set + */ + function FunctionProto(properties) { + this.input = []; + this.output = []; + this.attribute = []; + this.attributeProto = []; + this.node = []; + this.opsetImport = []; + if (properties) + for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; + } + + /** + * FunctionProto name. + * @member {string} name + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.name = ""; + + /** + * FunctionProto input. + * @member {Array.} input + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.input = $util.emptyArray; + + /** + * FunctionProto output. + * @member {Array.} output + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.output = $util.emptyArray; + + /** + * FunctionProto attribute. + * @member {Array.} attribute + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attribute = $util.emptyArray; + + /** + * FunctionProto attributeProto. 
+ * @member {Array.} attributeProto + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.attributeProto = $util.emptyArray; + + /** + * FunctionProto node. + * @member {Array.} node + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.node = $util.emptyArray; + + /** + * FunctionProto docString. + * @member {string} docString + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.docString = ""; + + /** + * FunctionProto opsetImport. + * @member {Array.} opsetImport + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.opsetImport = $util.emptyArray; + + /** + * FunctionProto domain. + * @member {string} domain + * @memberof onnx.FunctionProto + * @instance + */ + FunctionProto.prototype.domain = ""; + + /** + * Creates a new FunctionProto instance using the specified properties. + * @function create + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto=} [properties] Properties to set + * @returns {onnx.FunctionProto} FunctionProto instance + */ + FunctionProto.create = function create(properties) { + return new FunctionProto(properties); + }; + + /** + * Encodes the specified FunctionProto message. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. 
+ * @function encode + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encode = function encode(message, writer) { + if (!writer) + writer = $Writer.create(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) + writer.uint32(/* id 1, wireType 2 =*/10).string(message.name); + if (message.input != null && message.input.length) + for (var i = 0; i < message.input.length; ++i) + writer.uint32(/* id 4, wireType 2 =*/34).string(message.input[i]); + if (message.output != null && message.output.length) + for (var i = 0; i < message.output.length; ++i) + writer.uint32(/* id 5, wireType 2 =*/42).string(message.output[i]); + if (message.attribute != null && message.attribute.length) + for (var i = 0; i < message.attribute.length; ++i) + writer.uint32(/* id 6, wireType 2 =*/50).string(message.attribute[i]); + if (message.node != null && message.node.length) + for (var i = 0; i < message.node.length; ++i) + $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/58).fork()).ldelim(); + if (message.docString != null && Object.hasOwnProperty.call(message, "docString")) + writer.uint32(/* id 8, wireType 2 =*/66).string(message.docString); + if (message.opsetImport != null && message.opsetImport.length) + for (var i = 0; i < message.opsetImport.length; ++i) + $root.onnx.OperatorSetIdProto.encode(message.opsetImport[i], writer.uint32(/* id 9, wireType 2 =*/74).fork()).ldelim(); + if (message.domain != null && Object.hasOwnProperty.call(message, "domain")) + writer.uint32(/* id 10, wireType 2 =*/82).string(message.domain); + if (message.attributeProto != null && message.attributeProto.length) + for (var i = 0; i < message.attributeProto.length; ++i) + $root.onnx.AttributeProto.encode(message.attributeProto[i], writer.uint32(/* id 11, 
wireType 2 =*/90).fork()).ldelim(); + return writer; + }; + + /** + * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link onnx.FunctionProto.verify|verify} messages. + * @function encodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {onnx.IFunctionProto} message FunctionProto message or plain object to encode + * @param {$protobuf.Writer} [writer] Writer to encode to + * @returns {$protobuf.Writer} Writer + */ + FunctionProto.encodeDelimited = function encodeDelimited(message, writer) { + return this.encode(message, writer).ldelim(); + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer. + * @function decode + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @param {number} [length] Message length if known beforehand + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decode = function decode(reader, length) { + if (!(reader instanceof $Reader)) + reader = $Reader.create(reader); + var end = length === undefined ? 
reader.len : reader.pos + length, message = new $root.onnx.FunctionProto(); + while (reader.pos < end) { + var tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + message.name = reader.string(); + break; + } + case 4: { + if (!(message.input && message.input.length)) + message.input = []; + message.input.push(reader.string()); + break; + } + case 5: { + if (!(message.output && message.output.length)) + message.output = []; + message.output.push(reader.string()); + break; + } + case 6: { + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push(reader.string()); + break; + } + case 11: { + if (!(message.attributeProto && message.attributeProto.length)) + message.attributeProto = []; + message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + break; + } + case 7: { + if (!(message.node && message.node.length)) + message.node = []; + message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + break; + } + case 8: { + message.docString = reader.string(); + break; + } + case 9: { + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + break; + } + case 10: { + message.domain = reader.string(); + break; + } + default: + reader.skipType(tag & 7); + break; + } + } + return message; + }; + + /** + * Decodes a FunctionProto message from the specified reader or buffer, length delimited. 
+ * @function decodeDelimited + * @memberof onnx.FunctionProto + * @static + * @param {$protobuf.Reader|Uint8Array} reader Reader or buffer to decode from + * @returns {onnx.FunctionProto} FunctionProto + * @throws {Error} If the payload is not a reader or valid buffer + * @throws {$protobuf.util.ProtocolError} If required fields are missing + */ + FunctionProto.decodeDelimited = function decodeDelimited(reader) { + if (!(reader instanceof $Reader)) + reader = new $Reader(reader); + return this.decode(reader, reader.uint32()); + }; + + /** + * Verifies a FunctionProto message. + * @function verify + * @memberof onnx.FunctionProto + * @static + * @param {Object.} message Plain object to verify + * @returns {string|null} `null` if valid, otherwise the reason why it is not + */ + FunctionProto.verify = function verify(message) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) + return "name: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) + return "input: array expected"; + for (var i = 0; i < message.input.length; ++i) + if (!$util.isString(message.input[i])) + return "input: string[] expected"; + } + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) + return "output: array expected"; + for (var i = 0; i < message.output.length; ++i) + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; + for (var i = 0; i < message.attribute.length; ++i) + if (!$util.isString(message.attribute[i])) + return "attribute: string[] expected"; + } + if (message.attributeProto != null && message.hasOwnProperty("attributeProto")) { + if 
(!Array.isArray(message.attributeProto)) + return "attributeProto: array expected"; + for (var i = 0; i < message.attributeProto.length; ++i) { + var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); + if (error) + return "attributeProto." + error; + } + } + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) + return "node: array expected"; + for (var i = 0; i < message.node.length; ++i) { + var error = $root.onnx.NodeProto.verify(message.node[i]); + if (error) + return "node." + error; + } + } + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.opsetImport != null && message.hasOwnProperty("opsetImport")) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; + for (var i = 0; i < message.opsetImport.length; ++i) { + var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); + if (error) + return "opsetImport." + error; + } + } + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) + return "domain: string expected"; + return null; + }; + + /** + * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. 
+ * @function fromObject + * @memberof onnx.FunctionProto + * @static + * @param {Object.} object Plain object + * @returns {onnx.FunctionProto} FunctionProto + */ + FunctionProto.fromObject = function fromObject(object) { + if (object instanceof $root.onnx.FunctionProto) + return object; + var message = new $root.onnx.FunctionProto(); + if (object.name != null) + message.name = String(object.name); + if (object.input) { + if (!Array.isArray(object.input)) + throw TypeError(".onnx.FunctionProto.input: array expected"); + message.input = []; + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); + } + if (object.output) { + if (!Array.isArray(object.output)) + throw TypeError(".onnx.FunctionProto.output: array expected"); + message.output = []; + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); + } + if (object.attribute) { + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.FunctionProto.attribute: array expected"); + message.attribute = []; + for (var i = 0; i < object.attribute.length; ++i) + message.attribute[i] = String(object.attribute[i]); + } + if (object.attributeProto) { + if (!Array.isArray(object.attributeProto)) + throw TypeError(".onnx.FunctionProto.attributeProto: array expected"); + message.attributeProto = []; + for (var i = 0; i < object.attributeProto.length; ++i) { + if (typeof object.attributeProto[i] !== "object") + throw TypeError(".onnx.FunctionProto.attributeProto: object expected"); + message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); + } + } + if (object.node) { + if (!Array.isArray(object.node)) + throw TypeError(".onnx.FunctionProto.node: array expected"); + message.node = []; + for (var i = 0; i < object.node.length; ++i) { + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.FunctionProto.node: object expected"); + message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); + } 
+ } + if (object.docString != null) + message.docString = String(object.docString); + if (object.opsetImport) { + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.FunctionProto.opsetImport: array expected"); + message.opsetImport = []; + for (var i = 0; i < object.opsetImport.length; ++i) { + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.FunctionProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + } + } + if (object.domain != null) + message.domain = String(object.domain); + return message; + }; + + /** + * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. + * @function toObject + * @memberof onnx.FunctionProto + * @static + * @param {onnx.FunctionProto} message FunctionProto + * @param {$protobuf.IConversionOptions} [options] Conversion options + * @returns {Object.} Plain object + */ + FunctionProto.toObject = function toObject(message, options) { + if (!options) + options = {}; + var object = {}; + if (options.arrays || options.defaults) { + object.input = []; + object.output = []; + object.attribute = []; + object.node = []; + object.opsetImport = []; + object.attributeProto = []; + } + if (options.defaults) { + object.name = ""; + object.docString = ""; + object.domain = ""; + } + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.input && message.input.length) { + object.input = []; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; + } + if (message.output && message.output.length) { + object.output = []; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; + } + if (message.attribute && message.attribute.length) { + object.attribute = []; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = message.attribute[j]; + } + if (message.node && 
message.node.length) { + object.node = []; + for (var j = 0; j < message.node.length; ++j) + object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + } + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.opsetImport && message.opsetImport.length) { + object.opsetImport = []; + for (var j = 0; j < message.opsetImport.length; ++j) + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.attributeProto && message.attributeProto.length) { + object.attributeProto = []; + for (var j = 0; j < message.attributeProto.length; ++j) + object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); + } + return object; + }; + + /** + * Converts this FunctionProto to JSON. + * @function toJSON + * @memberof onnx.FunctionProto + * @instance + * @returns {Object.} JSON object + */ + FunctionProto.prototype.toJSON = function toJSON() { + return this.constructor.toObject(this, $protobuf.util.toJSONOptions); + }; + + /** + * Gets the default type url for FunctionProto + * @function getTypeUrl + * @memberof onnx.FunctionProto + * @static + * @param {string} [typeUrlPrefix] your custom typeUrlPrefix(default "type.googleapis.com") + * @returns {string} The default type url + */ + FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { + if (typeUrlPrefix === undefined) { + typeUrlPrefix = "type.googleapis.com"; + } + return typeUrlPrefix + "/onnx.FunctionProto"; + }; + + return FunctionProto; + })(); + + return onnx; +})(); + +module.exports = $root; diff --git a/js/node/test/test-utils.ts b/js/node/test/test-utils.ts index 968e8a1881810..3eef90356a335 100644 --- a/js/node/test/test-utils.ts +++ b/js/node/test/test-utils.ts @@ -4,10 +4,11 @@ import assert from 'assert'; import * as fs from 
'fs-extra'; import {jsonc} from 'jsonc'; -import * as onnx_proto from 'onnx-proto'; import {InferenceSession, Tensor} from 'onnxruntime-common'; import * as path from 'path'; +import * as onnx_proto from './ort-schema/protobuf/onnx'; + export const TEST_ROOT = __dirname; export const TEST_DATA_ROOT = path.join(TEST_ROOT, 'testdata'); diff --git a/js/package-lock.json b/js/package-lock.json index c87a58a3196d6..c16a8b59a3a6f 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -3391,9 +3391,9 @@ } }, "node_modules/normalize-package-data/node_modules/semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true, "bin": { "semver": "bin/semver" @@ -7011,9 +7011,9 @@ }, "dependencies": { "semver": { - "version": "5.7.1", - "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.1.tgz", - "integrity": "sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==", + "version": "5.7.2", + "resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz", + "integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==", "dev": true } } diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 0ed7d887fc5e5..57219c50f39aa 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -61,7 +61,6 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtDevice # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import OrtValue # noqa: F401 from onnxruntime.capi.onnxruntime_inference_collection import SparseTensor # noqa: F401 -from onnxruntime.capi.training import * # 
noqa: F403 # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end try: # noqa: SIM105 diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 60e0b1c061a43..4a6743e9e5c52 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -169,6 +170,60 @@ Status CreateInputFeatureProvider(const std::unordered_map mlmultiarray_buffer_size) { + if (mlmultiarray_buffer == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "mlmultiarray_buffer has no data"); + } + + const size_t num_elements = array_info.count; + const auto onnx_data_type = tensor_info->data_type; + switch (onnx_data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + const auto output_data_byte_size = num_elements * sizeof(float); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, + "CoreML output buffer size and expected output size differ"); + memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + const auto output_data_byte_size = num_elements * sizeof(int32_t); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, + "CoreML output buffer size and expected output size differ"); + memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + break; + } + // For this case, since Coreml Spec only uses int32 for model output while onnx provides + // int64 for model output data type. 
We are doing a type casting (int32 -> int64) here + // when copying the model to ORT + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + ORT_RETURN_IF_NOT(array_info.dataType == MLMultiArrayDataTypeInt32, + "CoreML output data type is not MLMultiArrayDataTypeInt32"); + ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == num_elements * sizeof(int32_t), + "CoreML output buffer size and expected output size differ"); + const auto model_output_span = gsl::span{static_cast(mlmultiarray_buffer), num_elements}; + const auto output_span = gsl::span{static_cast(tensor_buffer), num_elements}; + std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), + [](int32_t v) { return static_cast(v); }); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Output data type is not supported, actual type: ", onnx_data_type); + } + return Status::OK(); +} } // namespace NS_ASSUME_NONNULL_BEGIN @@ -298,9 +353,9 @@ - (Status)predict:(const std::unordered_map&)inputs return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); } - auto* data = [output_value multiArrayValue]; + MLMultiArray* data = [output_value multiArrayValue]; - const auto coreml_static_output_shape = [&]() { + const auto coreml_static_output_shape = [data]() { InlinedVector result; result.reserve(data.shape.count); for (NSNumber* dim in data.shape) { @@ -324,41 +379,21 @@ - (Status)predict:(const std::unordered_map&)inputs ") do not match"); } - const void* model_output_buffer = data.dataPointer; - - if (model_output_buffer == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model_output_buffer has no data for ", output_name); - } - - const auto onnx_data_type = output_tensor_info.data_type; - switch (onnx_data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - const auto output_data_byte_size = num_elements * sizeof(float); - memcpy(output_buffer, model_output_buffer, output_data_byte_size); - 
break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - const auto output_data_byte_size = num_elements * sizeof(int32_t); - memcpy(output_buffer, model_output_buffer, output_data_byte_size); - break; - } - // For this case, since Coreml Spec only uses int32 for model output while onnx provides - // int64 for model output data type. We are doing a type casting (int32 -> int64) here - // when copying the model to ORT - case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - ORT_RETURN_IF_NOT(data.dataType == MLMultiArrayDataTypeInt32, - "CoreML output data type is not MLMultiArrayDataTypeInt32"); - - const auto model_output_span = gsl::span{static_cast(model_output_buffer), num_elements}; - const auto output_span = gsl::span{static_cast(output_buffer), num_elements}; - std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), - [](int32_t v) { return static_cast(v); }); - break; - } - default: - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Output data type is not supported, actual type: ", onnx_data_type); + ORT_RETURN_IF_NOT(IsArrayContiguous(data), + "Non-contiguous output MLMultiArray is not currently supported"); + __block Status copy_status; + const auto* tensor_info = &output_tensor_info; + // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions + if (@available(macOS 12.3, iOS 15.4, *)) { + [data getBytesWithHandler:^(const void* bytes, NSInteger size) { + copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, tensor_info, size); + }]; + } else { + // disable size check as old API does not return buffer length + copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, tensor_info, std::nullopt); } + if (!copy_status.IsOK()) + return copy_status; } } } diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py deleted file mode 100644 index 4cf2e5d7f7588..0000000000000 --- 
a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ /dev/null @@ -1,1026 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import copy -import os -import unittest - -import numpy as np -import onnx -import torch -import torch.nn as nn -import torch.nn.functional as F -from helper import get_name -from numpy.testing import assert_allclose -from torchvision import datasets, transforms - -import onnxruntime -from onnxruntime.capi.ort_trainer import ( - IODescription, - LossScaler, - ModelDescription, - ORTTrainer, - generate_sample, - load_checkpoint, - save_checkpoint, -) - -SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) - - -def ort_trainer_learning_rate_description(): - return IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - - -def remove_extra_info(model_desc): - simple_model_desc = copy.deepcopy(model_desc) - for input_desc in simple_model_desc.inputs_: - input_desc.dtype_ = None - input_desc.num_classes_ = None - for output_desc in simple_model_desc.outputs_: - output_desc.dtype_ = None - output_desc.num_classes_ = None - return simple_model_desc - - -def bert_model_description(): - vocab_size = 30528 - input_ids_desc = IODescription( - "input_ids", - ["batch", "max_seq_len_in_batch"], - torch.int64, - num_classes=vocab_size, - ) - segment_ids_desc = IODescription("segment_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2) - input_mask_desc = IODescription("input_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2) - masked_lm_labels_desc = IODescription( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - torch.int64, - num_classes=vocab_size, - ) - next_sentence_labels_desc = IODescription( - "next_sentence_labels", - [ - "batch", - ], - torch.int64, - num_classes=2, - ) - loss_desc = IODescription("loss", [], torch.float32) - - return ModelDescription( - [ - input_ids_desc, - segment_ids_desc, - input_mask_desc, - 
masked_lm_labels_desc, - next_sentence_labels_desc, - ], - [loss_desc], - ) - - -def map_optimizer_attributes(name): - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) - if no_decay: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - else: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6} - - -def generate_sample_batch(desc, batch_size, device): - desc_ = copy.deepcopy(desc) - desc_.shape_[0] = batch_size - sample = generate_sample(desc_, device) - return sample - - -def create_ort_trainer( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc=True, - loss_scaler=None, - deepspeed_zero_stage=0, -): - model_desc = bert_model_description() - simple_model_desc = remove_extra_info(model_desc) if use_simple_model_desc else model_desc - learning_rate_description = ort_trainer_learning_rate_description() - device = torch.device("cuda", 0) - - onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx")) - - model = ORTTrainer( - onnx_model, - None, - simple_model_desc, - "LambOptimizer", - map_optimizer_attributes, - learning_rate_description, - device, - gradient_accumulation_steps=gradient_accumulation_steps, - world_rank=0, - world_size=1, - loss_scaler=loss_scaler, - use_mixed_precision=use_mixed_precision, - allreduce_post_accumulation=allreduce_post_accumulation, - deepspeed_zero_stage=deepspeed_zero_stage, - ) - - return model, model_desc, device - - -def run_bert_training_test( - gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc=True, - use_internel_loss_scale=False, -): - torch.manual_seed(1) - onnxruntime.set_seed(1) - - loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None - - model, model_desc, device = create_ort_trainer( - gradient_accumulation_steps, - use_mixed_precision, - 
allreduce_post_accumulation, - use_simple_model_desc, - loss_scaler, - ) - - if loss_scaler is None: - loss_scaler = LossScaler(model.loss_scale_input_name, True) - - input_ids_batches = [] - segment_ids_batches = [] - input_mask_batches = [] - masked_lm_labels_batches = [] - next_sentence_labels_batches = [] - batch_size = 16 - num_batches = 8 - for _batch in range(num_batches): - input_ids_batches = [ - *input_ids_batches, - generate_sample_batch(model_desc.inputs_[0], batch_size, device), - ] - segment_ids_batches = [ - *segment_ids_batches, - generate_sample_batch(model_desc.inputs_[1], batch_size, device), - ] - input_mask_batches = [ - *input_mask_batches, - generate_sample_batch(model_desc.inputs_[2], batch_size, device), - ] - masked_lm_labels_batches = [ - *masked_lm_labels_batches, - generate_sample_batch(model_desc.inputs_[3], batch_size, device), - ] - next_sentence_labels_batches = [ - *next_sentence_labels_batches, - generate_sample_batch(model_desc.inputs_[4], batch_size, device), - ] - - lr_batch_list = [ - 0.0000000e00, - 4.6012269e-07, - 9.2024538e-07, - 1.3803681e-06, - 1.8404908e-06, - 2.3006135e-06, - 2.7607362e-06, - 3.2208588e-06, - 3.6809815e-06, - ] - - actual_losses = [] - actual_all_finites = [] - - for batch_count in range(num_batches): - input_ids = generate_sample_batch(model_desc.inputs_[0], batch_size, device) - segment_ids = generate_sample_batch(model_desc.inputs_[1], batch_size, device) - input_mask = generate_sample_batch(model_desc.inputs_[2], batch_size, device) - masked_lm_labels = generate_sample_batch(model_desc.inputs_[3], batch_size, device) - next_sentence_labels = generate_sample_batch(model_desc.inputs_[4], batch_size, device) - lr = lr_batch_list[batch_count] - - learning_rate = torch.tensor([lr]).to(device) - training_args = [ - input_ids, - segment_ids, - input_mask, - masked_lm_labels, - next_sentence_labels, - learning_rate, - ] - if use_mixed_precision: - if not use_internel_loss_scale: - loss_scale = 
torch.tensor([loss_scaler.loss_scale_]).to(device) - training_args.append(loss_scale) - actual_loss = model.train_step(*training_args) - if isinstance(actual_loss, (list, tuple)): - assert len(actual_loss) == 2 - actual_loss, actual_all_finite = actual_loss - if not use_internel_loss_scale: - loss_scaler.update_loss_scale(actual_all_finite.item()) - actual_all_finites = [ - *actual_all_finites, - actual_all_finite.cpu().numpy().item(0), - ] - - actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)] - else: - loss = model(*training_args) - actual_losses = [*actual_losses, loss.cpu().numpy().item(0)] - - if batch_count == num_batches - 1: - # test eval_step api with fetches at the end of the training. - # if eval_step is called during the training, it will affect the actual training loss (training session is stateful). - eval_loss = model.eval_step( - input_ids, - segment_ids, - input_mask, - masked_lm_labels, - next_sentence_labels, - fetches=["loss"], - ) - eval_loss = eval_loss.cpu().numpy().item(0) - - # If using internal loss scale, all_finites are handled internally too. 
- if use_mixed_precision and not use_internel_loss_scale: - return actual_losses, actual_all_finites, eval_loss - else: - return actual_losses, eval_loss - - -class MNISTWrapper: - class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - self.register_buffer("bias_buffer", torch.tensor(1e-6)) - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - out = torch.add(out, self.bias_buffer.to(out.dtype)) - return out - - class NeuralNetWithLoss(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x, target): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return F.nll_loss(F.log_softmax(out, dim=1), target), out - - def my_loss(x, target): # noqa: N805 - return F.nll_loss(F.log_softmax(x, dim=1), target) - - def train_with_trainer(self, learningRate, trainer, device, train_loader, epoch): - actual_losses = [] - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - args_log_interval = 100 - if batch_idx % args_log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - actual_losses = [*actual_losses, loss.cpu().numpy().item()] - - return actual_losses - - # TODO: comple this once ORT training can do evaluation. 
- def test_with_trainer(self, trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = F.log_softmax(trainer.eval_step((data), fetches=["probability"]), dim=1) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, - correct, - len(test_loader.dataset), - 100.0 * correct / len(test_loader.dataset), - ) - ) - - return test_loss, correct / len(test_loader.dataset) - - def mnist_model_description(): - input_desc = IODescription("input1", ["batch", 784], torch.float32) - label_desc = IODescription( - "label", - [ - "batch", - ], - torch.int64, - num_classes=10, - ) - loss_desc = IODescription("loss", [], torch.float32) - probability_desc = IODescription("probability", ["batch", 10], torch.float32) - return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc]) - - def get_loaders(self): - args_batch_size = 64 - args_test_batch_size = 1000 - - kwargs = {"num_workers": 0, "pin_memory": True} - # set shuffle to False to get deterministic data set among different torch version - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - os.path.join(SCRIPT_DIR, "data"), - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args_batch_size, - shuffle=False, - **kwargs, - ) - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - os.path.join(SCRIPT_DIR, "data"), - train=False, - transform=transforms.Compose([transforms.ToTensor(), 
transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args_test_batch_size, - shuffle=False, - **kwargs, - ) - - return train_loader, test_loader - - def get_model(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - - # warning: changes the pytorch random generator state - model = MNISTWrapper.NeuralNet(input_size, hidden_size, num_classes) - model_desc = MNISTWrapper.mnist_model_description() - return model, model_desc - - def get_model_with_internal_loss(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - - # warning: changes the pytorch random generator state - model = MNISTWrapper.NeuralNetWithLoss(input_size, hidden_size, num_classes) - model_desc = MNISTWrapper.mnist_model_description() - return model, model_desc - - def get_trainer( - self, - model, - model_desc, - device, - onnx_opset_ver=12, - frozen_weights=[], # noqa: B006 - internal_loss_fn=False, - get_lr_this_step=None, - optimizer="SGDOptimizer", - ): - loss_fn = MNISTWrapper.my_loss if not internal_loss_fn else None - return ORTTrainer( - model, - loss_fn, - model_desc, - optimizer, - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - _opset_version=onnx_opset_ver, - frozen_weights=frozen_weights, - get_lr_this_step=get_lr_this_step, - ) - - -class TestOrtTrainer(unittest.TestCase): - def run_mnist_training_and_testing(onnx_opset_ver): # noqa: N805 - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - trainer = mnist.get_trainer(model, model_desc, device, onnx_opset_ver=onnx_opset_ver) - - learningRate = 0.01 # noqa: N806 - args_epochs = 2 - expected_losses = [ - 2.312044143676758, - 0.8018650412559509, - 0.5819257497787476, - 0.47025489807128906, - 0.35800155997276306, - 0.41124576330184937, - 0.2731882333755493, - 0.4201386570930481, - 0.39458805322647095, - 0.38380366563796997, - 0.2722422480583191, - 
0.24230478703975677, - 0.23505745828151703, - 0.33442264795303345, - 0.21140924096107483, - 0.31545233726501465, - 0.18556523323059082, - 0.3453553020954132, - 0.29598352313041687, - 0.3595045208930969, - ] - - expected_test_losses = [0.3145490005493164, 0.256188737487793] - expected_test_accuracies = [0.9075, 0.9265] - - actual_losses = [] - actual_test_losses, actual_accuracies = [], [] - for epoch in range(1, args_epochs + 1): - actual_losses = [ - *actual_losses, - *mnist.train_with_trainer(learningRate, trainer, device, train_loader, epoch), - ] - - test_loss, accuracy = mnist.test_with_trainer(trainer, device, test_loader) - actual_test_losses = [*actual_test_losses, test_loss] - actual_accuracies = [*actual_accuracies, accuracy] - - # if you update outcomes, also do so for resume from checkpoint test - # args_checkpoint_epoch = 1 - # if epoch == args_checkpoint_epoch: - # state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()} - # torch.save(state, get_name("ckpt_mnist.pt")) - - print("actual_losses=", actual_losses) - print("actual_test_losses=", actual_test_losses) - print("actual_accuracies=", actual_accuracies) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # import pdb; pdb.set_trace() - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_test_losses, - actual_test_losses, - rtol=rtol, - err_msg="test loss mismatch", - ) - assert_allclose( - expected_test_accuracies, - actual_accuracies, - rtol=rtol, - err_msg="test accuracy mismatch", - ) - - def test_mnist_training_and_testing_opset12(self): - TestOrtTrainer.run_mnist_training_and_testing(onnx_opset_ver=12) - - def test_mnist_resume_training_and_testing(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - learningRate = 0.01 # 
noqa: N806 - args_epochs = 2 - args_checkpoint_epoch = 1 - # should match those in test without checkpointing - expected_losses = [ - 0.26509523391723633, - 0.24135658144950867, - 0.2397943139076233, - 0.3351520597934723, - 0.20998981595039368, - 0.31488314270973206, - 0.18481917679309845, - 0.34727591276168823, - 0.2971782684326172, - 0.3609251379966736, - ] - - expected_test_losses = [0.25632242965698243] - expected_test_accuracies = [0.9264] - - actual_losses = [] - actual_test_losses, actual_accuracies = [], [] - - # restore from checkpoint - resume_trainer = mnist.get_trainer(model, model_desc, device) - checkpoint = torch.load(get_name("ckpt_mnist.pt"), map_location="cpu") - torch.set_rng_state(checkpoint["rng_state"]) - resume_trainer.load_state_dict(checkpoint["model"], strict=True) - - # continue .. - for epoch in range(args_checkpoint_epoch + 1, args_epochs + 1): - actual_losses = [ - *actual_losses, - *mnist.train_with_trainer(learningRate, resume_trainer, device, train_loader, epoch), - ] - - test_loss, accuracy = mnist.test_with_trainer(resume_trainer, device, test_loader) - actual_test_losses = [*actual_test_losses, test_loss] - actual_accuracies = [*actual_accuracies, accuracy] - - print("actual_losses=", actual_losses) - print("actual_test_losses=", actual_test_losses) - print("actual_accuracies=", actual_accuracies) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # import pdb; pdb.set_trace() - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_test_losses, - actual_test_losses, - rtol=rtol, - err_msg="test loss mismatch", - ) - assert_allclose( - expected_test_accuracies, - actual_accuracies, - rtol=rtol, - err_msg="test accuracy mismatch", - ) - - def test_mnist_state_dict(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, 
model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - state_dict = trainer.state_dict() - assert state_dict == {} - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - state_dict = trainer.state_dict() - assert state_dict.keys() == { - "fc1.bias", - "fc1.weight", - "fc2.bias", - "fc2.weight", - "bias_buffer", - } - - def test_mnist_save_as_onnx(self): - torch.manual_seed(1) - device = torch.device("cuda") - onnx_file_name = "mnist.onnx" - if os.path.exists(onnx_file_name): - os.remove(onnx_file_name) - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - trainer.save_as_onnx(onnx_file_name) - assert not os.path.exists(onnx_file_name) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - trainer.save_as_onnx(onnx_file_name) - assert os.path.exists(onnx_file_name) - - def test_mnist_device(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - for model_device in [torch.device("cpu"), torch.device("cuda")]: - model.to(model_device) - trainer = mnist.get_trainer(model, model_desc, device) - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - def test_mnist_initializer_names(self): - 
torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - assert ({n.name for n in trainer.onnx_model_.graph.initializer} - {"bias_buffer"}) == { - n for n, t in model.named_parameters() - } - - def test_mnist_initializer_names_with_internal_loss(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model_with_internal_loss() - - def get_lr_this_step(global_step): - learningRate = 0.02 # noqa: N806 - return torch.tensor([learningRate]) - - trainer = mnist.get_trainer( - model, - model_desc, - device, - internal_loss_fn=True, - get_lr_this_step=get_lr_this_step, - ) - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target) - - assert {n.name for n in trainer.onnx_model_.graph.initializer} == {n for n, t in model.named_parameters()} - - def test_mnist_frozen_weight(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=["fc1.weight"]) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_1 = 
trainer.state_dict()["fc1.weight"] - fc2_trainstep_1 = trainer.state_dict()["fc2.weight"] - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_2 = trainer.state_dict()["fc1.weight"] - fc2_trainstep_2 = trainer.state_dict()["fc2.weight"] - assert np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and not np.array_equal(fc2_trainstep_1, fc2_trainstep_2) - - def test_mnist_torch_buffer(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device) - - learningRate = 0.02 # noqa: N806 - - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_1 = trainer.state_dict()["fc1.weight"] - bias_buffer_trainstep_1 = trainer.state_dict()["bias_buffer"] - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - fc1_trainstep_2 = trainer.state_dict()["fc1.weight"] - bias_buffer_trainstep_2 = trainer.state_dict()["bias_buffer"] - assert not np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and np.array_equal( - bias_buffer_trainstep_1, bias_buffer_trainstep_2 - ) - - def test_mnist_frozen_weight_checkpoint(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=["fc1.weight"]) - - learningRate = 0.02 # noqa: N806 - - # do one train step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - # do one eval step - data, target = 
next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.eval_step(data, target) - - # save checkpoint, load model and compare - state_dict = trainer.state_dict() - - new_model, _ = mnist.get_model() - trainer = mnist.get_trainer(new_model, model_desc, device, frozen_weights=["fc1.weight"]) - trainer.load_state_dict(state_dict) - - ckpt_loss, _ = trainer.eval_step(data, target) - assert loss == ckpt_loss - - loaded_state_dict = trainer.state_dict() - assert state_dict.keys() == loaded_state_dict.keys() - - def test_mnist_training_checkpoint(self): - torch.manual_seed(1) - device = torch.device("cuda") - - mnist = MNISTWrapper() - train_loader, test_loader = mnist.get_loaders() - model, model_desc = mnist.get_model() - - trainer = mnist.get_trainer( - model, - model_desc, - device, - optimizer="LambOptimizer", - frozen_weights=["fc1.weight"], - ) - - learningRate = 0.02 # noqa: N806 - - # do 5 train step - for _i in range(5): - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - - # do one eval step - data, target = next(iter(train_loader)) - data, target = data.to(device), target.to(device) - data = data.reshape(data.shape[0], -1) - - loss, _ = trainer.eval_step(data, target) - - # save checkpoint, load model and compare - state_dict = trainer.state_dict() - - new_model, _ = mnist.get_model() - trainer = mnist.get_trainer( - new_model, - model_desc, - device, - optimizer="LambOptimizer", - frozen_weights=["fc1.weight"], - ) - trainer.load_state_dict(state_dict) - - ckpt_loss, _ = trainer.eval_step(data, target) - assert loss == ckpt_loss - - loaded_state_dict = trainer.state_dict() - assert state_dict.keys() == loaded_state_dict.keys() - for key in state_dict: - assert np.array_equal(state_dict[key], loaded_state_dict[key]) - 
- def test_bert_training_basic(self): - expected_losses = [ - 11.027887, - 11.108191, - 11.055356, - 11.040912, - 10.960277, - 11.02691, - 11.082471, - 10.920979, - ] - expected_eval_loss = [10.958977] - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=False, - ) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # print('losses expected: ', expected_losses) - # print('losses actual: ', actual_losses) - # print('eval_loss expected: ', expected_eval_loss) - # print('eval_loss actual: ', actual_eval_loss) - # import pdb; pdb.set_trace() - - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_gradient_accumulation(self): - expected_losses = [ - 11.027887, - 11.108191, - 11.055354, - 11.040904, - 10.960266, - 11.026897, - 11.082475, - 10.920998, - ] - expected_eval_loss = [10.958998] - - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=4, - use_mixed_precision=False, - allreduce_post_accumulation=False, - ) - - # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs - # print('losses expected: ', expected_losses) - # print('losses actual: ', actual_losses) - # print('eval_loss expected: ', expected_eval_loss) - # print('eval_loss actual: ', actual_eval_loss) - # import pdb; pdb.set_trace() - - rtol = 1e-03 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_checkpointing_basic(self): - model, _, _ = create_ort_trainer( - gradient_accumulation_steps=1, - use_mixed_precision=False, - 
allreduce_post_accumulation=True, - use_simple_model_desc=True, - loss_scaler=None, - ) - sd = model.state_dict() - - # modify one of the default values - sd["bert.encoder.layer.0.attention.output.LayerNorm.weight"] += 1 - model.load_state_dict(sd) - - ckpt_dir = "testdata" - save_checkpoint(model, ckpt_dir, "bert_toy_save_test") - del model - - # create new model - model2, _, _ = create_ort_trainer( - gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=True, - use_simple_model_desc=True, - loss_scaler=None, - ) - - # load changed checkpoint - load_checkpoint(model2, ckpt_dir, "bert_toy_save_test") - loaded_sd = model2.state_dict() - - for k, v in loaded_sd.items(): - assert torch.all(torch.eq(v, sd[k])) - - def test_wrap_model_loss_fn_state_dict(self): - torch.manual_seed(1) - device = torch.device("cuda") - - class LinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - - def forward(self, y=None, x=None): - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - pt_model = LinearModel() - data = torch.randn(2, 2) - label = torch.tensor([0, 1], dtype=torch.int64) - input_desc = IODescription("x", [2, 2], torch.float32) - label_desc = IODescription( - "label", - [ - 2, - ], - torch.int64, - num_classes=4, - ) - output_desc = IODescription("output", [2, 4], torch.float32) - loss_desc = IODescription("loss", [], torch.float32) - model_desc = ModelDescription([input_desc, label_desc], [loss_desc, output_desc]) - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - def get_lr_this_step(global_step): - learningRate = 0.02 # noqa: N806 - return torch.tensor([learningRate]) - - ort_trainer = ORTTrainer( - pt_model, - loss_fn, - model_desc, - "SGDOptimizer", - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - get_lr_this_step=get_lr_this_step, - ) - 
ort_trainer.train_step(x=data, label=label) - state_dict = ort_trainer.state_dict() - assert state_dict.keys() == {"linear.bias", "linear.weight"} - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py deleted file mode 100644 index 3b994e6f26710..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import unittest - -from numpy.testing import assert_allclose, assert_array_equal -from onnxruntime_test_ort_trainer import run_bert_training_test - - -class TestOrtTrainer(unittest.TestCase): - def test_bert_training_mixed_precision(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006105422973633, - 11.047048568725586, - 11.027417182922363, - 11.015759468078613, - 11.060905456542969, - 10.971782684326172, - ] - expected_all_finites = [True, True, True, True, True, True, True, True] - expected_eval_loss = [10.959012985229492] - actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_mixed_precision_internal_loss_scale(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006105422973633, - 11.047048568725586, - 11.027417182922363, - 11.015759468078613, - 11.060905456542969, - 
10.971782684326172, - ] - expected_eval_loss = [10.959012985229492] - actual_losses, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=1, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - use_internel_loss_scale=True, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - def test_bert_training_gradient_accumulation_mixed_precision(self): - expected_losses = [ - 11.034248352050781, - 11.125300407409668, - 11.006077766418457, - 11.047025680541992, - 11.027434349060059, - 11.0156831741333, - 11.060973167419434, - 10.971841812133789, - ] - expected_all_finites = [True, True] - expected_eval_loss = [10.95903205871582] - actual_losses, actual_all_finites, actual_eval_loss = run_bert_training_test( - gradient_accumulation_steps=4, - use_mixed_precision=True, - allreduce_post_accumulation=False, - use_simple_model_desc=False, - ) - - rtol = 1e-02 - assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") - assert_allclose( - expected_eval_loss, - actual_eval_loss, - rtol=rtol, - err_msg="evaluation loss mismatch", - ) - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py b/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py deleted file mode 100644 index 540f39b797bdb..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -import unittest - -import torch -import torch.nn as nn -from numpy.testing import assert_allclose -from onnxruntime_test_ort_trainer import map_optimizer_attributes, ort_trainer_learning_rate_description -from onnxruntime_test_training_unittest_utils import process_dropout - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer - - -class TestTrainingDropout(unittest.TestCase): - def setUp(self): - torch.manual_seed(1) - onnxruntime.set_seed(1) - - @unittest.skip( - "Temporarily disable this test. The graph below will trigger ORT to " - "sort backward graph before forward graph which gives incorrect result. " - "https://github.com/microsoft/onnxruntime/issues/16801" - ) - def test_training_and_eval_dropout(self): - class TwoDropoutNet(nn.Module): - def __init__(self, drop_prb_1, drop_prb_2, dim_size): - super().__init__() - self.drop_1 = nn.Dropout(drop_prb_1) - self.drop_2 = nn.Dropout(drop_prb_2) - self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32)) - - def forward(self, x): - x = x + self.weight_1 - x = self.drop_1(x) - x = self.drop_2(x) - output = x - return output[0] - - dim_size = 3 - device = torch.device("cuda", 0) - # This will drop all values, therefore expecting all 0 in output tensor - model = TwoDropoutNet(0.999, 0.999, dim_size) - input_desc = IODescription("input", [dim_size], torch.float32) - output_desc = IODescription("output", [], torch.float32) - model_desc = ModelDescription([input_desc], [output_desc]) - lr_desc = ort_trainer_learning_rate_description() - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes, - lr_desc, - device, - postprocess_model=process_dropout, - world_rank=0, - world_size=1, - ) - input = torch.ones(dim_size, dtype=torch.float32).to(device) - expected_training_output = [0.0] - expected_eval_output = [1.0] - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [input, 
learning_rate] - train_output = model.train_step(*input_args) - - rtol = 1e-04 - assert_allclose( - expected_training_output, - train_output.item(), - rtol=rtol, - err_msg="dropout training loss mismatch", - ) - - eval_output = model.eval_step(input) - assert_allclose( - expected_eval_output, - eval_output.item(), - rtol=rtol, - err_msg="dropout eval loss mismatch", - ) - - # Do another train step to make sure it's using original ratios - train_output_2 = model.train_step(*input_args) - assert_allclose( - expected_training_output, - train_output_2.item(), - rtol=rtol, - err_msg="dropout training loss 2 mismatch", - ) - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py b/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py deleted file mode 100644 index 3d3feca06a99b..0000000000000 --- a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np -from onnx import numpy_helper - - -def get_node_index(model, node): - i = 0 - while i < len(model.graph.node): - if model.graph.node[i] == node: - break - i += 1 - return i if i < len(model.graph.node) else None - - -def add_const(model, name, output, t_value=None, f_value=None): - const_node = model.graph.node.add() - const_node.op_type = "Constant" - const_node.name = name - const_node.output.extend([output]) - attr = const_node.attribute.add() - attr.name = "value" - if t_value is not None: - attr.type = 4 - attr.t.CopyFrom(t_value) - else: - attr.type = 1 - attr.f = f_value - return const_node - - -def process_dropout(model): - dropouts = [] - index = 0 - for node in model.graph.node: - if node.op_type == "Dropout": - new_dropout = model.graph.node.add() - new_dropout.op_type = "TrainableDropout" - new_dropout.name = "TrainableDropout_%d" % index - # make ratio node - ratio = np.asarray([node.attribute[0].f], dtype=np.float32) - 
print(ratio.shape) - ratio_value = numpy_helper.from_array(ratio) - ratio_node = add_const( - model, - "dropout_node_ratio_%d" % index, - "dropout_node_ratio_%d" % index, - t_value=ratio_value, - ) - print(ratio_node) - new_dropout.input.extend([node.input[0], ratio_node.output[0]]) - new_dropout.output.extend(node.output) - dropouts.append(get_node_index(model, node)) - index += 1 - dropouts.sort(reverse=True) - for d in dropouts: - del model.graph.node[d] - model.opset_import[0].version = 10 diff --git a/orttraining/orttraining/python/checkpointing_utils.py b/orttraining/orttraining/python/checkpointing_utils.py deleted file mode 100644 index 460b9982297d1..0000000000000 --- a/orttraining/orttraining/python/checkpointing_utils.py +++ /dev/null @@ -1,127 +0,0 @@ -import os - -import torch - - -def list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"): - ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)] - ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)] - ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names] - - assert len(ckpt_file_names) > 0, 'No checkpoint files found with prefix "{}" in directory {}.'.format( - checkpoint_prefix, checkpoint_dir - ) - return ckpt_file_names - - -def get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None): - SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806 - MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806 - - if is_partitioned: - filename = MULTIPLE_CHECKPOINT_FILENAME.format( - prefix=prefix, world_rank=world_rank, world_size=(world_size - 1) - ) - else: - filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix) - - return filename - - -def _split_state_dict(state_dict): - optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] - split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} - for k, v in 
state_dict.items(): - mode = "fp32_param" - for optim_key in optimizer_keys: - if k.startswith(optim_key): - mode = "optimizer" - break - if k.endswith("_fp16"): - mode = "fp16_param" - split_sd[mode][k] = v - return split_sd - - -class CombineZeroCheckpoint: - def __init__(self, checkpoint_files, clean_state_dict=None): - assert len(checkpoint_files) > 0, "No checkpoint files passed" - self.checkpoint_files = checkpoint_files - self.clean_state_dict = clean_state_dict - self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1 - assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files" - self.weight_shape_map = dict() - self.sharded_params = set() - - def _split_name(self, name: str): - name_split = name.split("_view_") - view_num = None - if len(name_split) > 1: - view_num = int(name_split[1]) - optimizer_key = "" - mp_suffix = "" - if name_split[0].startswith("Moment_1"): - optimizer_key = "Moment_1_" - elif name_split[0].startswith("Moment_2"): - optimizer_key = "Moment_2_" - elif name_split[0].startswith("Update_Count"): - optimizer_key = "Update_Count_" - elif name_split[0].endswith("_fp16"): - mp_suffix = "_fp16" - param_name = name_split[0] - if optimizer_key: - param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split("_fp16")[0] - return param_name, optimizer_key, view_num, mp_suffix - - def _update_weight_statistics(self, name, value): - if name not in self.weight_shape_map: - self.weight_shape_map[name] = value.size() # original shape of tensor - - def _reshape_tensor(self, key): - value = self.aggregate_state_dict[key] - weight_name, _, _, _ = self._split_name(key) - set_size = self.weight_shape_map[weight_name] - self.aggregate_state_dict[key] = value.reshape(set_size) - - def _aggregate(self, param_dict): - for k, v in param_dict.items(): - weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k) - if view_num is not None: - # parameter is sharded - 
param_name = optimizer_key + weight_name + mp_suffix - - if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]: - self.sharded_params.add(param_name) - # Found a previous shard of the param, concatenate shards ordered by ranks - self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v)) - else: - self.aggregate_state_dict[param_name] = v - else: - if k in self.aggregate_state_dict: - assert (self.aggregate_state_dict[k] == v).all(), "Unsharded params must have the same value" - else: - self.aggregate_state_dict[k] = v - self._update_weight_statistics(weight_name, v) - - def aggregate_checkpoints(self): - checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0] - self.aggregate_state_dict = dict() - - for i in range(self.world_size): - checkpoint_name = get_checkpoint_name(checkpoint_prefix, True, i, self.world_size) - rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu")) - if "model" in rank_state_dict: - rank_state_dict = rank_state_dict["model"] - - if self.clean_state_dict: - rank_state_dict = self.clean_state_dict(rank_state_dict) - - rank_state_dict = _split_state_dict(rank_state_dict) - self._aggregate(rank_state_dict["fp16_param"]) - self._aggregate(rank_state_dict["fp32_param"]) - self._aggregate(rank_state_dict["optimizer"]) - - for k in self.sharded_params: - self._reshape_tensor(k) - return self.aggregate_state_dict diff --git a/orttraining/orttraining/python/deprecated/__init__.py b/orttraining/orttraining/python/deprecated/__init__.py deleted file mode 100644 index 6e02db707bc47..0000000000000 --- a/orttraining/orttraining/python/deprecated/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- -from onnxruntime.capi._pybind_state import TrainingParameters # noqa: F401 -from onnxruntime.capi.training.training_session import TrainingSession # noqa: F401 diff --git a/orttraining/orttraining/python/deprecated/training_session.py b/orttraining/orttraining/python/deprecated/training_session.py deleted file mode 100644 index a6900578e174b..0000000000000 --- a/orttraining/orttraining/python/deprecated/training_session.py +++ /dev/null @@ -1,68 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import os # noqa: F401 -import sys # noqa: F401 - -from onnxruntime.capi import _pybind_state as C -from onnxruntime.capi.onnxruntime_inference_collection import IOBinding # noqa: F401 -from onnxruntime.capi.onnxruntime_inference_collection import ( - InferenceSession, - Session, - check_and_normalize_provider_args, -) - - -class TrainingSession(InferenceSession): - def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, provider_options=None): - Session.__init__(self) - - if sess_options: - self._sess = C.TrainingSession(sess_options) - else: - self._sess = C.TrainingSession() - - # providers needs to be passed explicitly as of ORT 1.10 - # retain the pre-1.10 behavior by setting to the available providers. 
- if providers is None: - providers = C.get_available_providers() - - providers, provider_options = check_and_normalize_provider_args( - providers, provider_options, C.get_available_providers() - ) - - if isinstance(path_or_bytes, str): - config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options) - elif isinstance(path_or_bytes, bytes): - config_result = self._sess.read_bytes(path_or_bytes, parameters, providers, provider_options) - else: - raise TypeError(f"Unable to load from type '{type(path_or_bytes)}'") - - self.loss_scale_input_name = config_result.loss_scale_input_name - - self._inputs_meta = self._sess.inputs_meta - self._outputs_meta = self._sess.outputs_meta - - def __del__(self): - if self._sess: - self._sess.finalize() - - def get_state(self): - return self._sess.get_state() - - def get_model_state(self, include_mixed_precision_weights=False): - return self._sess.get_model_state(include_mixed_precision_weights) - - def get_optimizer_state(self): - return self._sess.get_optimizer_state() - - def get_partition_info_map(self): - return self._sess.get_partition_info_map() - - def load_state(self, dict, strict=False): - self._sess.load_state(dict, strict) - - def is_output_fp32_node(self, output_name): - return self._sess.is_output_fp32_node(output_name) diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py deleted file mode 100644 index 5286c087cfb64..0000000000000 --- a/orttraining/orttraining/python/ort_trainer.py +++ /dev/null @@ -1,1241 +0,0 @@ -import io -import os -import warnings - -import numpy as np -import onnx -import torch -import torch.nn -import torch.onnx -from onnx import helper, numpy_helper -from packaging.version import Version as LooseVersion - -import onnxruntime as ort -import onnxruntime.capi.pt_patch -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - -from ..training import postprocess -from .checkpointing_utils import 
CombineZeroCheckpoint, get_checkpoint_name, list_checkpoint_files - -DEFAULT_OPSET_VERSION = 14 - - -class IODescription: - def __init__(self, name, shape, dtype=None, num_classes=None): - self.name_ = name - self.shape_ = shape - self.dtype_ = dtype - self.num_classes_ = num_classes - - -class ModelDescription: - def __init__(self, inputs, outputs): - self.inputs_ = inputs - self.outputs_ = outputs - - -def resolve_symbolic_dimensions(inputs, input_descs, output_descs): - import copy - - output_descs_copy = copy.deepcopy(output_descs) - resolved_dims = {} - for input, input_desc in zip(inputs, input_descs): - for i, axis in enumerate(input_desc.shape_): - if isinstance(axis, str): - resolved_dims[axis] = input.size()[i] - - for output_desc in output_descs_copy: - for i, axis in enumerate(output_desc.shape_): - if isinstance(axis, str): - output_desc.shape_[i] = resolved_dims[axis] - - if any(isinstance(axis, str) for axis in output_desc.shape_ for output_desc in output_descs): - raise RuntimeError("Cannot run model with unknown output dimensions") - - return output_descs_copy - - -def generate_sample(desc, device=None): - # symbolic dimensions are described with strings. set symbolic dimensions to be 1 - size = [s if isinstance(s, (int)) else 1 for s in desc.shape_] - if desc.num_classes_: - return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device) - else: - return torch.randn(size, dtype=desc.dtype_).to(device) - - -def get_device_index(device): - if type(device) == str: # noqa: E721 - # could be 'cuda:0', 'cuda:1', or 'cpu'. 
with cpu, set index=0 - device = torch.device(device) - return 0 if device.index is None else device.index - - -def input_get_device_index(input): - if isinstance(input, (list, tuple)): - device_index = get_device_index(input[0].device) - else: - device_index = get_device_index(input.device) - - return device_index - - -def get_all_gradients_finite_arg_name(session): - all_fp16_or_fp32_gradients_finite_node_args = [x for x in session._outputs_meta if "all_gradients_finite" in x.name] - if len(all_fp16_or_fp32_gradients_finite_node_args) < 1: - raise RuntimeError( - "Failed to find a group NodeArg with name that matches 'all_gradients_finite'\ - from the training session." - ) - - return all_fp16_or_fp32_gradients_finite_node_args[0].name - - -def get_group_accumulated_gradients_output_node_arg_name(session): - # TODO: get the constant string via pybind. - # optimizer_graph_builder BuildGroupNode with fixed string: 'Group_Accumulated_Gradients' - accumulated_gradients_output_node_args = [ - x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name - ] - if len(accumulated_gradients_output_node_args) != 1: - raise RuntimeError( - "Failed to find a group NodeArg with name that matches 'Group_Accumulated_Gradients'\ - from the training session." 
- ) - - return accumulated_gradients_output_node_args[0].name - - -def ort_training_session_run_helper(session, iobinding, inputs, input_descs, output_descs, device, run_options=None): - for input, input_desc in zip(inputs, input_descs): - device_index = input_get_device_index(input) - iobinding.bind_input( - input_desc.name_, - input.device.type, - device_index, - dtype_torch_to_numpy(input.dtype), - list(input.size()), - input.data_ptr(), - ) - - output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs) - torch_outputs = {} - for output_desc in output_descs_resolved: - torch_tensor = torch.zeros( - output_desc.shape_, - device=device, - dtype=output_desc.eval_dtype_ if hasattr(output_desc, "eval_dtype_") else output_desc.dtype_, - ) - iobinding.bind_output( - output_desc.name_, - torch_tensor.device.type, - get_device_index(device), - dtype_torch_to_numpy(torch_tensor.dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - torch_outputs[output_desc.name_] = torch_tensor - - session.run_with_iobinding(iobinding, run_options) - return torch_outputs - - -def FuseSofmaxNLLToSoftmaxCE(onnx_model): # noqa: N802 - nll_count = 0 - while True: - nll_count = nll_count + 1 - nll_loss_node = None - nll_loss_node_index = 0 - for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss": - nll_loss_node = node - break - - if nll_loss_node is None: - break - - softmax_node = None - softmax_node_index = 0 - label_input_name = None - weight_input_name = None - for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "LogSoftmax": - # has to be connected to nll_loss - if len(nll_loss_node.input) > 2: - weight_input_name = nll_loss_node.input[2] - if node.output[0] == nll_loss_node.input[0]: - softmax_node = node - label_input_name = nll_loss_node.input[1] - break - elif node.output[0] == 
nll_loss_node.input[1]: - softmax_node = node - label_input_name = nll_loss_node.input[0] - break - else: - if softmax_node is not None: - break - - if softmax_node is None: - break - - # delete nll_loss and LogSoftmax nodes in order - if nll_loss_node_index < softmax_node_index: - del onnx_model.graph.node[softmax_node_index] - del onnx_model.graph.node[nll_loss_node_index] - else: - del onnx_model.graph.node[nll_loss_node_index] - del onnx_model.graph.node[softmax_node_index] - - probability_output_name = softmax_node.output[0] - node = onnx_model.graph.node.add() - inputs = ( - [softmax_node.input[0], label_input_name, weight_input_name] - if weight_input_name - else [softmax_node.input[0], label_input_name] - ) - node.CopyFrom( - onnx.helper.make_node( - "SparseSoftmaxCrossEntropy", - inputs, - [nll_loss_node.output[0], probability_output_name], - "nll_loss_node_" + str(nll_count), - ) - ) - - return onnx_model - - -def delete_input_with_name(input, name): - index = 0 - for i in input: - if i.name == name: - del input[index] - break - index = index + 1 - - -# reference: -# https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html -# https://pytorch.org/docs/stable/tensors.html -# also must map to types accepted by: -# MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type) -def dtype_torch_to_numpy(torch_dtype): - if torch_dtype == torch.float64 or torch_dtype == torch.double: - return np.float64 - elif torch_dtype == torch.float32 or torch_dtype == torch.float: - return np.float32 - elif torch_dtype == torch.float16 or torch_dtype == torch.half: - return np.float16 - elif torch_dtype == torch.int64 or torch_dtype == torch.long: - return np.longlong - elif torch_dtype == torch.int32 or torch_dtype == torch.int: - return np.int32 - elif torch_dtype == torch.int16 or torch_dtype == torch.short: - return np.int16 - elif torch_dtype == torch.bool: - return bool - else: - raise Exception("Torch type to numpy type mapping unavailable for: " + str(torch_dtype)) - - 
-class model_loss_cls(torch.nn.Module): # noqa: N801 - def __init__(self, model, loss_fn): - super().__init__() - self.model_ = model - self.loss_fn_ = loss_fn - - def forward(self, *inputs): - # here we assume input can be unpacked into input and label - input, label = inputs[:-1], inputs[-1] - preds = self.model_(*input) - return self.loss_fn_(preds, label), preds - - -class WrapModel(torch.nn.Module): - def __init__(self, model, loss_fn, input_names): - super().__init__() - self.model_ = model - self.loss_fn_ = loss_fn - self.input_names_ = input_names - - def forward(self, *inputs): - import inspect - - # *inputs is given by torch trace. It is in the order of input_names. - # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names. - sig = inspect.signature(self.model_.forward) - list(sig.parameters.keys()) - - input_dict = {} - for key in sig.parameters: - if key in self.input_names_: - input_dict[key] = inputs[self.input_names_.index(key)] - - model_out = self.model_(**input_dict) - if self.loss_fn_ is None: - return model_out - - label = inputs[-1] - preds = model_out - return self.loss_fn_(preds, label), preds - - -def wrap_for_input_match(model, loss_fn, input_names): - import inspect - - sig = inspect.signature(model.forward) - ordered_list_keys = list(sig.parameters.keys()) - if loss_fn: - sig_loss = inspect.signature(loss_fn) - if len(sig_loss.parameters) != 2: - raise RuntimeError("loss function should take two arguments - predict and label.") - - # label shall be the second input to loss_fn. - ordered_list_keys = [*ordered_list_keys, list(sig_loss.parameters.keys())[1]] - - # name match is needed only when input_names are a subset - # of expected inputs (inputs to model and loss_fn combined). - if len(input_names) > len(ordered_list_keys): - # this is likely the case where input arguments are packed. - # TODO: to unpack the input argument. 
- return model_loss_cls(model, loss_fn) if loss_fn else model - elif len(input_names) == len(ordered_list_keys): - # in this case, we do not require name match. - return model_loss_cls(model, loss_fn) if loss_fn else model - - if not all(x in ordered_list_keys for x in input_names): - # model desc has name(s) not matching the model signature. We cannot do anything in this case. - # better to warning the user. - return model_loss_cls(model, loss_fn) if loss_fn else model - - # if input_names match ordered_list_keys, there is not need for wrapping - match = True - for i, input_name in enumerate(input_names): - if input_name != ordered_list_keys[i]: - match = False - break - - if match: - return model_loss_cls(model, loss_fn) if loss_fn else model - - model = WrapModel(model, loss_fn, input_names) - - return model - - -def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, opset_version=DEFAULT_OPSET_VERSION): - # example: {input0:{0:'batch'}, input1:{0:'batch'}} - dynamic_axes = {} - for input in model_desc.inputs_: - symbolic_axis = {} - for i, axis in enumerate(input.shape_): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[input.name_] = symbolic_axis - - for output in model_desc.outputs_: - symbolic_axis = {} - for i, axis in enumerate(output.shape_): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[output.name_] = symbolic_axis - - input_names = [input.name_ for input in model_desc.inputs_] - output_names = [output.name_ for output in model_desc.outputs_] - - if isinstance(inputs, torch.Tensor): - inputs = [inputs] - if isinstance(inputs, dict): - sample_inputs = [inputs[k.name_].to(device=device) for k in model_desc.inputs_] - elif isinstance(inputs, (list, tuple)): - sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(model_desc.inputs_)] - else: - raise RuntimeError("Unexpected input type. 
Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.") - - # pytorch onnx exporter/trace does not try to match argument names. - # e.g. for models with optional inputs, it requires all inputs be present. - # this is a problem because the model graph depends on inputs provided. - model = wrap_for_input_match(model, loss_fn, input_names) - - model.eval() - with torch.no_grad(): - import copy - - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - try: - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(model) - except Exception: - model_copy = model - warnings.warn( - "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." - " Compute will continue, but unexpected results may occur!" - ) - - sample_outputs = model_copy(*sample_inputs_copy) - if isinstance(sample_outputs, torch.Tensor): - sample_outputs = [sample_outputs] - for sample_output, output_desc in zip(sample_outputs, model_desc.outputs_): - output_desc.dtype_ = sample_output.dtype - model.train() - - f = io.BytesIO() - - # Other export options to use(this is for backward compatibility). - other_export_options = {} - other_export_options["training"] = True - - # This option was added after 1.4 release. - if LooseVersion(torch.__version__) > LooseVersion("1.4.0") and LooseVersion(torch.__version__) < LooseVersion( - "1.10.0" - ): - other_export_options["enable_onnx_checker"] = False - # This option was added after 1.6 release. - if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - other_export_options["training"] = torch.onnx.TrainingMode.TRAINING - - # Deepcopy inputs, since input values may change after model run. 
- import copy - - sample_inputs_copy = copy.deepcopy(sample_inputs) - - # Enable contrib ops export from PyTorch - from onnxruntime.tools import pytorch_export_contrib_ops - - pytorch_export_contrib_ops.register() - - torch.onnx._export( - model, - tuple(sample_inputs_copy), - f, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False, - **other_export_options, - ) - - onnx_model = onnx.load_model_from_string(f.getvalue()) - - # Remove 'model_.' prefix introduced by model wrapper for initializers. - if isinstance(model, (WrapModel, model_loss_cls)): - replace_name_dict = {} - for n in onnx_model.graph.initializer: - if n.name.startswith("model_."): - replace_name_dict[n.name] = n.name[len("model_.") :] - n.name = replace_name_dict[n.name] - for n in onnx_model.graph.node: - for i, name in enumerate(n.input): - if name in replace_name_dict: - n.input[i] = replace_name_dict[name] - - return onnx_model - - -def create_ort_training_session_with_optimizer( - model, - device, - training_optimizer_name, - lr_params_feed_name, - map_optimizer_attributes, - world_rank=-1, - world_size=1, - gradient_accumulation_steps=1, - bind_parameters=False, - use_mixed_precision=False, - allreduce_post_accumulation=False, - deepspeed_zero_stage=0, - enable_grad_norm_clip=True, - frozen_weights=[], # noqa: B006 - opset_version=DEFAULT_OPSET_VERSION, - use_deterministic_compute=False, - use_memory_efficient_gradient=False, - enable_adasum=False, - optimized_model_filepath="", -): - output_name = model.graph.output[0].name - ort_parameters = ort.TrainingParameters() - ort_parameters.loss_output_name = output_name - ort_parameters.use_mixed_precision = use_mixed_precision - ort_parameters.world_rank = world_rank - ort_parameters.world_size = world_size - ort_parameters.gradient_accumulation_steps = gradient_accumulation_steps - ort_parameters.allreduce_post_accumulation = allreduce_post_accumulation - 
ort_parameters.deepspeed_zero_stage = deepspeed_zero_stage - ort_parameters.enable_grad_norm_clip = enable_grad_norm_clip - ort_parameters.set_gradients_as_graph_outputs = False - ort_parameters.use_memory_efficient_gradient = use_memory_efficient_gradient - ort_parameters.enable_adasum = enable_adasum - output_types = {} - for output in model.graph.output: - output_types[output.name] = output.type.tensor_type - - # pybind does not allow to add directly to ort_parameters.weights_to_train. - # Have to work around by using a temporary weights_to_train. - torch_params = {} - optimizer_attributes_map = {} - optimizer_int_attributes_map = {} - - unused_frozen_weights = [n for n in frozen_weights if n not in [i.name for i in model.graph.initializer]] - if unused_frozen_weights: - raise RuntimeError(f"{unused_frozen_weights} in frozen_weights not found in model weights.") - - weights_to_train = set() - for initializer in model.graph.initializer: - if initializer.name in frozen_weights: - continue - weights_to_train.add(initializer.name) - if map_optimizer_attributes is not None: - attributes = map_optimizer_attributes(initializer.name) - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - for k, v in attributes.items(): - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - else: - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - - if bind_parameters: - for initializer in model.graph.initializer: - torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device)) - delete_input_with_name(model.graph.input, initializer.name) - model.graph.input.extend( - [helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)] - ) 
- torch_params[initializer.name] = torch_tensor - - del model.graph.initializer[:] - - ort_parameters.weights_to_train = weights_to_train - ort_parameters.training_optimizer_name = training_optimizer_name - ort_parameters.lr_params_feed_name = lr_params_feed_name - ort_parameters.optimizer_attributes_map = optimizer_attributes_map - ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map - - sessionOptions = ort.SessionOptions() # noqa: N806 - sessionOptions.use_deterministic_compute = use_deterministic_compute - if len(optimized_model_filepath) > 0: - sessionOptions.optimized_model_filepath = optimized_model_filepath - session = ort.TrainingSession(model.SerializeToString(), ort_parameters, sessionOptions) - train_io_binding = session.io_binding() - eval_io_binding = session.io_binding() - - if bind_parameters: - for param in torch_params: - torch_tensor = torch_params[param] - - train_io_binding.bind_input( - param, - torch_tensor.device.type, - get_device_index(torch_tensor.device), - dtype_torch_to_numpy(torch_params[param].dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - eval_io_binding.bind_input( - param, - torch_tensor.device.type, - get_device_index(torch_tensor.device), - dtype_torch_to_numpy(torch_params[param].dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - - return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types - - -def save_checkpoint( - model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", checkpoint_state_dict=None, include_optimizer_state=True -): - if checkpoint_state_dict is None: - checkpoint_state_dict = {"model": model.state_dict(include_optimizer_state)} - else: - checkpoint_state_dict.update({"model": model.state_dict(include_optimizer_state)}) - - assert os.path.exists(checkpoint_dir), f"ERROR: Checkpoint directory doesn't exist: {checkpoint_dir}" - - checkpoint_name = get_checkpoint_name( - checkpoint_prefix, model.deepspeed_zero_stage_, 
model.world_rank, model.world_size - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if os.path.exists(checkpoint_file): - warnings.warn(f"{checkpoint_file} already exists, overwriting.") - - torch.save(checkpoint_state_dict, checkpoint_file) - - -def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict): - checkpoint_name = get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if is_partitioned: - assert_msg = ( - f"Couldn't find checkpoint file {checkpoint_file}." - "Optimizer partitioning is enabled using ZeRO. Please make sure that the " - f"checkpoint file exists for rank {model.world_rank} of {model.world_size}." - ) - else: - assert_msg = f"Couldn't find checkpoint file {checkpoint_file}." - - assert os.path.exists(checkpoint_file), assert_msg - - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - - model.load_state_dict(checkpoint_state["model"], strict=strict) - del checkpoint_state["model"] - return checkpoint_state - - -def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict): - checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - - ckpt_agg = CombineZeroCheckpoint(checkpoint_files) - aggregate_state_dict = ckpt_agg.aggregate_checkpoints() - - model.load_state_dict(aggregate_state_dict, strict=strict) - - # aggregate other keys in the state_dict. 
- # Values will be overwritten for matching keys among workers - all_checkpoint_states = {} - for checkpoint_file in checkpoint_files: - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - del checkpoint_state["model"] - all_checkpoint_states.update(checkpoint_state) - return all_checkpoint_states - - -def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): - checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - is_partitioned = False - if len(checkpoint_files) > 1: - warnings.warn( - f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." - "Attempting to load ZeRO checkpoint." - ) - is_partitioned = True - if (not model.deepspeed_zero_stage_) and is_partitioned: - return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict) - else: - return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) - - -class ORTTrainer: - def __init__( - self, - model, - loss_fn, - model_desc, - training_optimizer_name, - map_optimizer_attributes, - learning_rate_description, - device, - gradient_accumulation_steps=1, - world_rank=0, - world_size=1, - use_mixed_precision=False, - allreduce_post_accumulation=False, - global_step=0, - get_lr_this_step=None, - loss_scaler=None, - deepspeed_zero_stage=0, - enable_grad_norm_clip=True, - frozen_weights=[], # noqa: B006 - _opset_version=DEFAULT_OPSET_VERSION, - _enable_internal_postprocess=True, - _extra_postprocess=None, - _use_deterministic_compute=False, - use_memory_efficient_gradient=False, - run_symbolic_shape_infer=False, - enable_adasum=False, - optimized_model_filepath="", - ): - super().__init__() - """ - Initialize ORTTrainer. - - Args: - - model: one of - - a PyTorch model (class that inherits from torch.nn.Module) - - a combined PyTorch model and loss function. 
- Inputs to this combined PyTorch model are a concatenation of the - model's input and the loss function's label input. - Outputs are a concatenation of the loss function's output and the - model's output. - - a combined ONNX model and loss function. - loss_fn: one of - - a PyTorch loss function if 'model' is a PyTorch model. A loss - function takes two inputs (prediction, label) and outputs a loss - tensor. - - None if model is already combined with a loss function. - model_desc: Specify input/output shapes, types, and names. - Must be consistent with the training model. - training_optimizer_name: one of - - 'SGDOptimizer' - - 'AdamOptimizer' - - 'LambOptimizer' - map_optimizer_attributes: for optimizers with weight-dependent - parameters. A callable that maps weight name to a set of optimization - parameters. - Defaults to None. - learning_rate_description: the name, shape and type of the learning - rate in form of IODescription(Learning_Rate_Name, [1,], torch.float32). - Because learning_rate is an input to the training model, - Learning_Rate_Name must be specified so that there is no name conflict - within the model. - device: device to store tensors (e.g. 'cpu', 'cuda', 'cuda:'). - gradient_accumulation_steps: number of training steps to accumulate - gradients before averaging and applying them. - Defaults to 1. - world_rank: rank id used for distributed training. - Defaults to 0. - world_size: number of ranks participating in distributed training. - Defaults to 1. - use_mixed_precision: flag to enable mixed precision (aka fp16). - Defaults to False. - allreduce_post_accumulation: controls whether overlaping gradient - computation is applied with allreduce. - Defaults to False. - global_step: training step that is used as input to 'get_lr_this_step'. - Defaults to 0. - get_lr_this_step: functor used as learning rate scheduler. - It uses 'global_step' as input. - Defaults to None. 
- loss_scaler: updates loss scale automatically when 'use_mixed_precision' - is specified. - Defaults to None. - deepspeed_zero_stage: controls whether to partition state using the DeepSpeed ZeRO technique. Stages 0 and 1 are supported. - Defaults to 0 (disabled). - enable_grad_norm_clip: enables gradient norm clipping. - Defaults to True. - frozen_weights: list of model parameters to be frozen (not trained). - Defaults to []. - _enable_internal_postprocess: whether to run or not the internal postprocesses. - Defaults to True - _extra_postprocess: a callable to postprocess the ONNX model that is converted from PyTorch. - Defaults to None - use_memory_efficient_gradient: use memory aware gradient builder. - Defaults to False - run_symbolic_shape_infer: run symbolic shape inference - Defaults to False - optimized_model_filepath: path to output the optimized training graph. - Defaults to "" (no output). - """ - warnings.warn( - "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", - FutureWarning, - ) - warnings.warn( - "DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it" - ) - self.is_train = True - - self.torch_model_ = None - self.onnx_model_ = None - self._enable_internal_postprocess = _enable_internal_postprocess - self._extra_postprocess = _extra_postprocess - - if isinstance(model, torch.nn.Module): - self.torch_model_ = model - self.loss_fn_ = loss_fn - self._torch_state_dict_keys = list(model.state_dict().keys()) - else: - self._torch_state_dict_keys = [] - self.onnx_model_ = model - if loss_fn is not None: - warnings.warn("loss_fn is not used when creating ORTTrainer because an ONNX model is provided.") - # TODO: accept loss_fn as an onnx model. 
build self.onnx_model_ with model and loss_fn - self.loss_fn_ = None - - if self._enable_internal_postprocess: - postprocess.run_postprocess(self.onnx_model_) - - if self._extra_postprocess: - self._extra_postprocess(self.onnx_model_) - - self.model_desc_ = model_desc - self.input_desc_with_lr = [*self.model_desc_.inputs_, learning_rate_description] - - self.world_rank = world_rank - self.world_size = world_size - self.use_mixed_precision = use_mixed_precision - - self.session = None - self.device_ = device - self.gradient_accumulation_steps = gradient_accumulation_steps - # we use self.current_step to count calls to train_step. It is used for gradient accumulation. - # gradients are being accumulated when self.current_step is not divisible by gradient_accumulation_steps. - # gradients are updated when self.current_step is divisible by gradient_accumulation_steps. - self.current_step = 0 - - # we use self.global_step_ to count optimizations being performed. - # it is used to calculate learning rate if self.get_lr_this_step_ is provided. 
- self.global_step_ = global_step - self.get_lr_this_step_ = get_lr_this_step - self.loss_scaler_ = loss_scaler - - if self.get_lr_this_step_ is not None or self.loss_scaler_ is not None: - warnings.warn("It is experimental to use learning rate scheduler and loss scaler inside ORTTrainer.") - self.training_optimizer_name_ = training_optimizer_name - self.learning_rate_description_ = learning_rate_description - self.map_optimizer_attributes_ = map_optimizer_attributes - self.allreduce_post_accumulation_ = allreduce_post_accumulation - self.deepspeed_zero_stage_ = deepspeed_zero_stage - self.enable_grad_norm_clip_ = enable_grad_norm_clip - self.frozen_weights_ = frozen_weights - self.opset_version_ = _opset_version - self.state_dict_ = None - self._use_deterministic_compute = _use_deterministic_compute - self.use_memory_efficient_gradient = use_memory_efficient_gradient - self.run_symbolic_shape_infer = run_symbolic_shape_infer - self.enable_adasum = enable_adasum - self.optimized_model_filepath = optimized_model_filepath - - # use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs. - # see prepare_input_and_fetches for more details. - self.loss_scale_input_name = "default_loss_scale_input_name" - - self._init_session() - - def _init_session(self): - if self.onnx_model_ is None: - return - - self._verify_fully_optimized_model(self.onnx_model_) - - if self.run_symbolic_shape_infer: - self.onnx_model_ = SymbolicShapeInference.infer_shapes( - self.onnx_model_, auto_merge=True, guess_output_rank=True - ) - - # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. 
- # for example, load_state_dict will be called before returing the function, and it calls _init_session again - del self.session - ( - self.session, - self.train_io_binding, - self.eval_io_binding, - self.output_name, - _, - self.output_types, - ) = create_ort_training_session_with_optimizer( - self.onnx_model_, - self.device_, - self.training_optimizer_name_, - self.learning_rate_description_.name_, - self.map_optimizer_attributes_, - self.world_rank, - self.world_size, - self.gradient_accumulation_steps, - bind_parameters=False, - use_mixed_precision=self.use_mixed_precision, - allreduce_post_accumulation=self.allreduce_post_accumulation_, - deepspeed_zero_stage=self.deepspeed_zero_stage_, - enable_grad_norm_clip=self.enable_grad_norm_clip_, - frozen_weights=self.frozen_weights_, - opset_version=self.opset_version_, - use_deterministic_compute=self._use_deterministic_compute, - use_memory_efficient_gradient=self.use_memory_efficient_gradient, - enable_adasum=self.enable_adasum, - optimized_model_filepath=self.optimized_model_filepath, - ) - - self.loss_scale_input_name = self.session.loss_scale_input_name - - if self.use_mixed_precision: - self.input_desc_with_lr_and_loss_scale = [ - *self.input_desc_with_lr, - IODescription(self.loss_scale_input_name, [], torch.float32), - ] - - # ORT backend has modified model output dtype from float32 to float16. - for o_desc in self.model_desc_.outputs_: - if ( - self.use_mixed_precision - and o_desc.dtype_ == torch.float32 - and not self.session.is_output_fp32_node(o_desc.name_) - ): - o_desc.eval_dtype_ = torch.float16 - else: - o_desc.eval_dtype_ = o_desc.dtype_ - - # gradient accumulation buffers are connected to a single node with a boolean, dimension 1 tensor output. - # add a matching output to drive gradient accumulation. 
- if self.gradient_accumulation_steps > 1: - self.output_desc_with_group_accumulated_gradients = [ - *self.model_desc_.outputs_, - IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool), - ] - - if self.use_mixed_precision: - # when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine - # if the gradient is usable. - self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [ - *self.model_desc_.outputs_, - IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool), - ] - - if self.state_dict_: - self.load_state_dict(self.state_dict_, self.strict_) - self.state_dict_ = None - - def _init_onnx_model(self, inputs): - if self.onnx_model_ is not None: - return - - if self.torch_model_ is not None: - # NOTE: pt model is moved to cpu to conserve gpu memory. - self.torch_model_.cpu() - # torch buffers created using 'register_buffer' are not meant to be trainable. - torch_buffers = list(dict(self.torch_model_.named_buffers()).keys()) - self.frozen_weights_ = self.frozen_weights_ + torch_buffers - self.onnx_model_ = convert_model_loss_fn_to_onnx( - self.torch_model_, - self.loss_fn_, - self.model_desc_, - torch.device("cpu"), - inputs, - opset_version=self.opset_version_, - ) - - if self._enable_internal_postprocess: - postprocess.run_postprocess(self.onnx_model_) - - if self._extra_postprocess: - self._extra_postprocess(self.onnx_model_) - - self._init_session() - - def train(self): - self.is_train = True - - def eval(self): - self.is_train = False - - def _update_onnx_model_initializers(self, state_tensors): - # replace the initializers with new value - new_weights = [] - replace_indices = [] - for i, w in enumerate(self.onnx_model_.graph.initializer): - if w.name in state_tensors: - new_weights.append(numpy_helper.from_array(state_tensors[w.name], w.name)) - replace_indices.append(i) - replace_indices.sort(reverse=True) - for w_i in replace_indices: - del 
self.onnx_model_.graph.initializer[w_i] - self.onnx_model_.graph.initializer.extend(new_weights) - - def state_dict(self, include_optimizer_state=True): - if not self.session: - warnings.warn( - "ONNXRuntime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling state_dict()." - ) - return {} - - # extract trained weights - session_state = self.session.get_state() - torch_state = {} - for name in session_state: - torch_state[name] = torch.from_numpy(session_state[name]) - - # extract untrained weights and buffer - for n in self.onnx_model_.graph.initializer: - if n.name not in torch_state: - torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n)) - - # Need to remove redundant initializers and name suffices to map back to original torch state names - if not include_optimizer_state and self._torch_state_dict_keys: - return {key: torch_state[key] for key in self._torch_state_dict_keys if key in torch_state} - return torch_state - - def load_state_dict(self, state_dict, strict=False): - # Note: It may happen ONNX model has not yet been initialized - # In this case we cache a reference to desired state and delay the restore until after initialization - # Unexpected behavior will result if the user changes the reference before initialization - if not self.session: - self.state_dict_ = state_dict - self.strict_ = strict - return - - # update onnx model from loaded state dict - cur_initializers_names = [n.name for n in self.onnx_model_.graph.initializer] - new_initializers = {} - - for name in state_dict: - if name in cur_initializers_names: - new_initializers[name] = state_dict[name].numpy() - elif strict: - raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.") - - self._update_onnx_model_initializers(new_initializers) - - # create new session based on updated onnx model - self.state_dict_ = None - self._init_session() - - # load training state - session_state = {name: 
state_dict[name].numpy() for name in state_dict} - self.session.load_state(session_state, strict) - - def save_as_onnx(self, path): - if not self.session: - warnings.warn( - "ONNXRuntime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling save_as_onnx()." - ) - return - state_tensors = self.session.get_state() - self._update_onnx_model_initializers(state_tensors) - - with open(path, "wb") as f: - f.write(self.onnx_model_.SerializeToString()) - - def _prepare_input_and_fetches( - self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs - ): - fetches = None - if type(args) == tuple and len(args) == 1 and type(args[0]) == list: # noqa: E721 - input = tuple(args[0]) - else: - input = args - - for input_desc in input_desc_with_: - if input_desc.name_ in kwargs: - input = (*input, kwargs[input_desc.name_]) - if internal_learning_rate is not None: - input = (*input, internal_learning_rate) - if internal_loss_scale is not None: - input = (*input, internal_loss_scale) - elif self.use_mixed_precision: - # loss_scale input name is needed to call train_step, for example: - # kwargs[model.loss_scale_input_name] = loss_scale - # outputs = model.train_step(*args, **kwargs) - # However, when first time train_step is called model.loss_scale_input_name is not set. - # To workaround this problem, we use the special name 'default_loss_scale_input_name' to indicate - # the loss_scale. - if "default_loss_scale_input_name" in kwargs: - input = (*input, kwargs["default_loss_scale_input_name"]) - - fetches = None - if "fetches" in kwargs: - fetches = kwargs["fetches"] - - return input, fetches - - def train_step(self, *args, **kwargs): - """ - inputs: model inputs, labels, learning rate, and, if in mixed_precision mode, loss_scale. - outputs: if fetches is not provided, outputs are loss and - (if in mixed mode and is finishing gradient accumulation) all_finite. 
- if fetches is provided, outputs contains these requested with fetches. - fetches: names of requested outputs - """ - - # inputs to the ONNX model includes inputs to the original PyTorch model - # plus learning rate and loss_scale if self.use_mixed_precision is True. - # 1. when there are internal learning_rate and loss_scale (in fp16 cases) generators, - # *args and **kwargs together contain ONLY and COMPLETE inputs to the PyTorch model. - # In this case, changes to the training script is minimized. - # 2. without internal learning rate and loss scale (in fp16 cases) generators, - # *args and **kwargs passed in from the training script shall contains - # inputs to the PyTorch model plus learning_rate and loss_scale. - # it optionally contains the fetches. - # localized arguments (*args) contains inputs to the ONNX model. - # named arguments can contain both inputs, learning_rate and loss_scale, and the fetches - - learning_rate, loss_scale = None, None - if self.get_lr_this_step_ is not None: - # $args, **kwargs contains inputs to the pytorch model - lr_this_step = self.get_lr_this_step_(self.global_step_) - learning_rate = torch.tensor([lr_this_step]) - if self.loss_scaler_ is not None and self.use_mixed_precision: - loss_scale = torch.tensor([self.loss_scaler_.loss_scale_]) - - if self.onnx_model_ is None: - sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) - self._init_onnx_model(sample_input) - - if self.use_mixed_precision: - input, fetches = self._prepare_input_and_fetches( - self.input_desc_with_lr_and_loss_scale, learning_rate, loss_scale, *args, **kwargs - ) - assert len(self.input_desc_with_lr_and_loss_scale) == len(input) - input_descs = self.input_desc_with_lr_and_loss_scale - else: - input, fetches = self._prepare_input_and_fetches( - self.input_desc_with_lr, learning_rate, loss_scale, *args, **kwargs - ) - assert len(self.input_desc_with_lr) == len(input) - input_descs = self.input_desc_with_lr 
- - self.current_step += 1 - - # handle gradient accumulation in fully optimized mode - run_options = None - has_if_all_finite = False - if fetches: - output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch] - elif self.current_step % self.gradient_accumulation_steps != 0: - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - output_desc = self.output_desc_with_group_accumulated_gradients - elif self.use_mixed_precision: - has_if_all_finite = True - output_desc = self.output_desc_with_all_fp_16_or_fp32_gradients_finite - else: - output_desc = self.model_desc_.outputs_ - - if not isinstance(input, (list, tuple)): - input = (input,) - - session_run_results = ort_training_session_run_helper( - self.session, self.train_io_binding, input, input_descs, output_desc, self.device_, run_options - ) - - if has_if_all_finite: - # After session run with all_fp32_gradients_finite, we need to clear the iobinding's output state. - # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce - # because all_fp32_gradients_finite is still in the feed. - self.train_io_binding.clear_binding_outputs() - all_finite = session_run_results[self.output_desc_with_all_fp_16_or_fp32_gradients_finite[-1].name_] - if self.loss_scaler_ is not None: - self.loss_scaler_.update_loss_scale(all_finite) - if all_finite: - # optimization has done, increase self.global_step_ - self.global_step_ = self.global_step_ + 1 - elif self.current_step % self.gradient_accumulation_steps == 0: - # optimization has done, increase self.global_step_ - self.global_step_ = self.global_step_ + 1 - - if fetches is not None: - results = [session_run_results[fetch] for fetch in fetches] - elif has_if_all_finite and self.loss_scaler_ is None: - # return descripted outputs plus the all_finite flag so that the training script can handle loss scaling. 
- results = [ - session_run_results[output_desc.name_] - for output_desc in self.output_desc_with_all_fp_16_or_fp32_gradients_finite - ] - else: - results = [session_run_results[output_desc.name_] for output_desc in self.model_desc_.outputs_] - return results[0] if len(results) == 1 else results - - def __call__(self, *args, **kwargs): - if self.is_train: - return self.train_step(*args, **kwargs) - else: - return self.eval_step(*args, **kwargs) - - def eval_step(self, *args, **kwargs): - """ - inputs: model inputs and/or labels. - outputs: if 'fetches' is not provided, outputs are loss and - (if in mixed mode and is finishing gradient accumulation) all_finite. - if fetches is provided, outputs contains these requested with fetches. - fetches: names of requested outputs - """ - - # with model_loss_cls, the last input is label, first output is loss - input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) - - if self.onnx_model_ is None: - if self.torch_model_ is not None: - self._init_onnx_model(input) - else: - raise RuntimeError( - "Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer." 
- ) - - input_desc = self.model_desc_.inputs_[0 : len(input)] - if fetches is None: - output_desc = self.model_desc_.outputs_ - else: - output_desc = [output for fetch in fetches for output in self.model_desc_.outputs_ if output.name_ == fetch] - - if not isinstance(input, (list, tuple)): - input = (input,) - - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - run_options.training_mode = False - - session_run_results = ort_training_session_run_helper( - self.session, self.eval_io_binding, input, input_desc, output_desc, self.device_, run_options - ) - - if len(session_run_results) == 1: - return session_run_results[next(iter(session_run_results.keys()))] - else: - return [session_run_results[output_desc.name_] for output_desc in output_desc] - - def _verify_fully_optimized_model(self, model): - assert len(model.graph.output) > 0 - # model's first output must be the loss tensor - if model.graph.output[0].type.tensor_type.elem_type not in { - onnx.TensorProto.FLOAT, - onnx.TensorProto.FLOAT16, - onnx.TensorProto.DOUBLE, - onnx.TensorProto.COMPLEX64, - onnx.TensorProto.COMPLEX128, - onnx.TensorProto.BFLOAT16, - onnx.TensorProto.FLOAT8E4M3FN, - onnx.TensorProto.FLOAT8E4M3FNUZ, - onnx.TensorProto.FLOAT8E5M2, - onnx.TensorProto.FLOAT8E5M2FNUZ, - }: - raise RuntimeError( - "the first output of a model to run with fully optimized ORT backend must be float types." - ) - if len(model.graph.output[0].type.tensor_type.shape.dim) != 0: - raise RuntimeError( - "the first output of a model to run with fully optimized ORT backend assumed to be loss and must be a scalar." 
- ) - - -class LossScaler: - def __init__( - self, - loss_scale_input_name, - is_dynamic_scale, - loss_scale=float(1 << 16), - up_scale_window=2000, - min_loss_scale=1.0, - max_loss_scale=float(1 << 24), - ): - super().__init__() - self.loss_scale_input_name_ = loss_scale_input_name - self.is_dynamic_scale_ = is_dynamic_scale - self.initial_loss_scale_ = loss_scale - self.up_scale_window_ = up_scale_window - self.min_loss_scale_ = min_loss_scale - self.max_loss_scale_ = max_loss_scale - self.loss_scale_ = loss_scale - self.stable_steps_ = 0 - - def update_loss_scale(self, is_all_finite): - if not self.is_dynamic_scale_: - return - - if is_all_finite: - self.stable_steps_ += 1 - - if self.stable_steps_ >= self.up_scale_window_: - self.loss_scale_ = min(self.max_loss_scale_, self.loss_scale_ * 2) - self.stable_steps_ = 0 - else: - self.loss_scale_ = max(self.min_loss_scale_, self.loss_scale_ / 2) - self.stable_steps_ = 0 - - def reset(self): - self.loss_scale_ = self.initial_loss_scale_ - self.stable_steps_ = 0 diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index a08e8bee99cee..bb1cb4bbd32f7 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -18,7 +18,6 @@ #include "core/session/environment.h" #include "core/session/custom_ops.h" #include "core/dlpack/dlpack_converter.h" -#include "orttraining/core/session/training_session.h" #include "orttraining/core/agent/training_agent.h" #include "orttraining/core/graph/gradient_config.h" #include "orttraining/core/graph/optimizer_config.h" @@ -113,14 +112,11 @@ struct TrainingParameters { std::unordered_set weights_to_train; std::unordered_set weights_not_to_train; - onnxruntime::training::TrainingSession::ImmutableWeights immutable_weights; - // optimizer std::string training_optimizer_name; std::string lr_params_feed_name = "Learning_Rate"; 
std::unordered_map> optimizer_attributes_map; std::unordered_map> optimizer_int_attributes_map; - onnxruntime::training::TrainingSession::OptimizerState optimizer_initial_state; std::unordered_map> sliced_schema; std::unordered_map sliced_axes; std::vector sliced_tensor_names; @@ -206,185 +202,6 @@ struct PyGradientGraphBuilderContext { local_registries_(local_registries) {} }; -// TODO: this method does not handle parallel optimization. -TrainingConfigurationResult ConfigureSessionForTraining( - training::PipelineTrainingSession* sess, TrainingParameters& parameters) { - // TODO tix, refactor the mpi related code to populate all fields correctly by default. - ORT_ENFORCE(parameters.data_parallel_size <= parameters.world_size, "data_parallel_size: ", parameters.data_parallel_size, ", world_size: ", parameters.world_size); - ORT_ENFORCE(parameters.horizontal_parallel_size <= parameters.world_size, "horizontal_parallel_size: ", parameters.horizontal_parallel_size, ", world_size: ", parameters.world_size); - ORT_ENFORCE(parameters.pipeline_parallel_size <= parameters.world_size, "pipeline_parallel_size: ", parameters.pipeline_parallel_size, ", world_size: ", parameters.world_size); - - // When DxHxP != the total number of ranks, we try adjusting D so that DxHxP == the total number of ranks. - if (parameters.world_size != parameters.data_parallel_size * parameters.horizontal_parallel_size * parameters.pipeline_parallel_size) { - ORT_ENFORCE(parameters.world_size % parameters.horizontal_parallel_size * parameters.pipeline_parallel_size == 0, - "D, H, P sizes are incorrect. 
To enable automatic correction, total number of ranks must be a divisible by HxP."); - - const auto new_data_parallel_size = parameters.world_size / (parameters.horizontal_parallel_size * parameters.pipeline_parallel_size); - parameters.data_parallel_size = new_data_parallel_size; - - const std::string msg = "Cannot distribute " + std::to_string(parameters.world_size) + " ranks for distributed computation with D=" + std::to_string(parameters.data_parallel_size) + - ", H=" + std::to_string(parameters.horizontal_parallel_size) + ", P=" + std::to_string(parameters.pipeline_parallel_size) + ", so D is automatically changed to " + std::to_string(new_data_parallel_size); - LOGS(*(sess->GetLogger()), WARNING) << msg; - } - - training::PipelineTrainingSession::TrainingConfiguration config{}; - config.weight_names_to_train = parameters.weights_to_train; - config.weight_names_to_not_train = parameters.weights_not_to_train; - config.immutable_weights = parameters.immutable_weights; - config.gradient_accumulation_steps = parameters.gradient_accumulation_steps; - - config.distributed_config.world_rank = parameters.world_rank; - config.distributed_config.world_size = parameters.world_size; - config.distributed_config.local_rank = parameters.local_rank; - config.distributed_config.local_size = parameters.local_size; - config.distributed_config.data_parallel_size = parameters.data_parallel_size; - config.distributed_config.horizontal_parallel_size = parameters.horizontal_parallel_size; - config.distributed_config.pipeline_parallel_size = parameters.pipeline_parallel_size; - config.distributed_config.num_pipeline_micro_batches = parameters.num_pipeline_micro_batches; - config.distributed_config.sliced_schema = parameters.sliced_schema; - config.distributed_config.sliced_axes = parameters.sliced_axes; - config.distributed_config.sliced_tensor_names = parameters.sliced_tensor_names; - - if (parameters.use_mixed_precision) { - 
training::PipelineTrainingSession::TrainingConfiguration::MixedPrecisionConfiguration mp{}; - mp.use_mixed_precision_initializers = true; - - config.mixed_precision_config = mp; - } - - if (config.distributed_config.pipeline_parallel_size > 1) { - training::PipelineTrainingSession::TrainingConfiguration::PipelineConfiguration pipeline_config; - - // Currently don't support auto-partition. User needs to pass in cut information for pipeline - pipeline_config.do_partition = true; - assert(!parameters.pipeline_cut_info_string.empty()); - - auto process_with_delimiter = [](std::string& input_str, const std::string& delimiter) { - std::vector result; - size_t pos = 0; - while ((pos = input_str.find(delimiter)) != std::string::npos) { - std::string token = input_str.substr(0, pos); - result.emplace_back(token); - input_str.erase(0, pos + delimiter.length()); - } - // push the last split of substring into result. - result.emplace_back(input_str); - return result; - }; - - auto process_cut_info = [&](std::string& cut_info_string) { - std::vector cut_list; - const std::string group_delimiter = ","; - const std::string edge_delimiter = ":"; - const std::string consumer_delimiter = "/"; - const std::string producer_consumer_delimiter = "-"; - - auto cut_info_groups = process_with_delimiter(cut_info_string, group_delimiter); - for (auto& cut_info_group : cut_info_groups) { - PipelineTrainingSession::TrainingConfiguration::CutInfo cut_info; - auto cut_edges = process_with_delimiter(cut_info_group, edge_delimiter); - for (auto& cut_edge : cut_edges) { - auto process_edge = process_with_delimiter(cut_edge, producer_consumer_delimiter); - if (process_edge.size() == 1) { - PipelineTrainingSession::TrainingConfiguration::CutEdge edge{process_edge[0]}; - cut_info.emplace_back(edge); - } else { - ORT_ENFORCE(process_edge.size() == 2); - auto consumer_list = process_with_delimiter(process_edge[1], consumer_delimiter); - - PipelineTrainingSession::TrainingConfiguration::CutEdge 
edge{process_edge[0], consumer_list}; - cut_info.emplace_back(edge); - } - } - cut_list.emplace_back(cut_info); - } - return cut_list; - }; - - pipeline_config.cut_list = process_cut_info(parameters.pipeline_cut_info_string); - config.pipeline_config = pipeline_config; - } - config.loss_name = parameters.loss_output_name; - - if (!parameters.training_optimizer_name.empty()) { - training::PipelineTrainingSession::TrainingConfiguration::OptimizerConfiguration opt{}; - opt.name = parameters.training_optimizer_name; - opt.learning_rate_input_name = parameters.lr_params_feed_name; - opt.weight_attributes_generator = [¶meters](const std::string& weight_name) { - const auto it = parameters.optimizer_attributes_map.find(weight_name); - ORT_ENFORCE( - it != parameters.optimizer_attributes_map.end(), - "Failed to find attribute map for weight ", weight_name); - return it->second; - }; - opt.weight_int_attributes_generator = [¶meters](const std::string& weight_name) { - const auto it = parameters.optimizer_int_attributes_map.find(weight_name); - ORT_ENFORCE( - it != parameters.optimizer_int_attributes_map.end(), - "Failed to find int attribute map for weight ", weight_name); - return it->second; - }; - opt.use_mixed_precision_moments = parameters.use_fp16_moments; - opt.do_all_reduce_in_mixed_precision_type = true; - // TODO: this mapping is temporary. - // For now, nccl allreduce kernel only implements for allreduce_post_accumulation - // hovorod allreduce kernel only implements for not allreduce_post_accumulation. - // eventually we will have one all reduce kernel and let opt to have - // an allreduce_post_accumulation option and remove the use_nccl option. 
- opt.use_nccl = parameters.allreduce_post_accumulation; - opt.deepspeed_zero = onnxruntime::training::ZeROConfig(parameters.deepspeed_zero_stage); - opt.enable_grad_norm_clip = parameters.enable_grad_norm_clip; - - // TODO reduction types - if (parameters.enable_adasum) { -#ifdef USE_CUDA - opt.adasum_reduction_type = training::AdasumReductionType::GpuHierarchicalReduction; -#else - opt.adasum_reduction_type = training::AdasumReductionType::CpuReduction; -#endif - } - - config.optimizer_config = opt; - } - - if (!parameters.optimizer_initial_state.empty()) { - config.init_optimizer_states = parameters.optimizer_initial_state; - } - - config.gradient_graph_config.use_memory_efficient_gradient = parameters.use_memory_efficient_gradient; - config.gradient_graph_config.set_gradients_as_graph_outputs = parameters.set_gradients_as_graph_outputs; - - config.graph_transformer_config.attn_dropout_recompute = parameters.attn_dropout_recompute; - config.graph_transformer_config.gelu_recompute = parameters.gelu_recompute; - config.graph_transformer_config.transformer_layer_recompute = parameters.transformer_layer_recompute; - config.graph_transformer_config.number_recompute_layers = parameters.number_recompute_layers; - config.graph_transformer_config.propagate_cast_ops_config.strategy = parameters.propagate_cast_ops_strategy; - config.graph_transformer_config.propagate_cast_ops_config.level = parameters.propagate_cast_ops_level; - config.graph_transformer_config.propagate_cast_ops_config.allow = parameters.propagate_cast_ops_allow; - - if (!parameters.model_after_graph_transforms_path.empty()) { - config.model_after_graph_transforms_path = ToPathString(parameters.model_after_graph_transforms_path); - } - if (!parameters.model_with_gradient_graph_path.empty()) { - config.model_with_gradient_graph_path = ToPathString(parameters.model_with_gradient_graph_path); - } - if (!parameters.model_with_training_graph_path.empty()) { - config.model_with_training_graph_path = 
ToPathString(parameters.model_with_training_graph_path); - } - - training::PipelineTrainingSession::TrainingConfigurationResult config_result{}; - - OrtPybindThrowIfError(sess->ConfigureForTraining(config, config_result)); - - TrainingConfigurationResult python_config_result{}; - if (config_result.mixed_precision_config_result.has_value()) { - const auto& mp_config_result = config_result.mixed_precision_config_result.value(); - python_config_result.loss_scale_input_name = mp_config_result.loss_scale_input_name; - } - - return python_config_result; -} - #if defined(USE_MPI) void CopyMPIContextToTrainingParameters(TrainingParameters& parameters, const logging::Logger* logger) { LOGS(*logger, INFO) << "MPIContext::GetInstance().GetWorldRank(): " << MPIContext::GetInstance().GetWorldRank(); @@ -424,7 +241,7 @@ std::unordered_map> Con return py_tensor_state; } -void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) { +void addObjectMethodsForTraining(py::module& m) { py::class_(m, "OrtValueCache") .def(py::init<>()) .def("insert", [](const OrtValueCachePtr& cache_ptr, std::string node_arg_name, OrtValue& value) { @@ -451,7 +268,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn py::class_ parameters(m, "TrainingParameters", R"pbdoc(Configuration information for training.)pbdoc"); parameters.def(py::init()) .def_readwrite("loss_output_name", &TrainingParameters::loss_output_name) - .def_readwrite("immutable_weights", &TrainingParameters::immutable_weights) .def_readwrite("weights_not_to_train", &TrainingParameters::weights_not_to_train) .def_readwrite("weights_to_train", &TrainingParameters::weights_to_train) .def_readwrite("sliced_tensor_names", &TrainingParameters::sliced_tensor_names) @@ -484,25 +300,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn .def_readwrite("data_parallel_size", &TrainingParameters::data_parallel_size) 
.def_readwrite("horizontal_parallel_size", &TrainingParameters::horizontal_parallel_size) .def_readwrite("pipeline_parallel_size", &TrainingParameters::pipeline_parallel_size) - .def("set_optimizer_initial_state", - [](TrainingParameters& parameters, const std::unordered_map>& py_state) -> void { - onnxruntime::training::TrainingSession::OptimizerState optim_state; - for (const auto& weight_it : py_state) { - auto state = weight_it.second; - NameMLValMap state_tensors; - for (auto& initializer : state) { - OrtValue ml_value; - - // InputDeflist is null because parameters havent been tied to session yet - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - CreateGenericMLValue(nullptr, GetAllocator(), "", initializer.second, &ml_value, true); - ThrowIfPyErrOccured(); - state_tensors.emplace(initializer.first, ml_value); - } - optim_state.emplace(weight_it.first, state_tensors); - } - parameters.optimizer_initial_state = optim_state; - }) .def_readwrite("model_after_graph_transforms_path", &TrainingParameters::model_after_graph_transforms_path) .def_readwrite("model_with_gradient_graph_path", &TrainingParameters::model_with_gradient_graph_path) .def_readwrite("model_with_training_graph_path", &TrainingParameters::model_with_training_graph_path) @@ -611,130 +408,6 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn }); #endif - py::class_ config_result(m, "TrainingConfigurationResult", "pbdoc(Configuration result for training.)pbdoc"); - config_result.def(py::init()) - .def_property_readonly("loss_scale_input_name", [](const TrainingConfigurationResult& result) -> py::object { - if (result.loss_scale_input_name.has_value()) { - return py::str{result.loss_scale_input_name.value()}; - } - return py::none(); - }); - - // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user - struct PyTrainingSession : public PyInferenceSession 
{ - PyTrainingSession(std::shared_ptr env, const PySessionOptions& so) - : PyInferenceSession(env, std::make_unique(so.value, *env)) { - } - ~PyTrainingSession() = default; - }; - - py::class_ training_session(m, "TrainingSession"); - training_session - .def(py::init([](const PySessionOptions& so) { - auto& training_env = GetTrainingEnv(); - return std::make_unique(training_env.GetORTEnv(), so); - })) - .def(py::init([]() { - auto& training_env = GetTrainingEnv(); - return std::make_unique(training_env.GetORTEnv(), GetDefaultCPUSessionOptions()); - })) - .def("finalize", [](py::object) { -#if defined(USE_MPI) -#ifdef _WIN32 - // https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-best-practices - // shutdown_mpi() is not called within MPIContext destructor because of DllMain's restriction - // call shutdown_mpi() here instead. - MPIContext::shutdown_mpi(); -#endif -#endif - }) - .def("load_model", [ep_registration_fn](PyTrainingSession* sess, const std::string& path, TrainingParameters& parameters, const std::vector& provider_types, const ProviderOptionsVector& provider_options) { - OrtPybindThrowIfError(sess->GetSessionHandle()->Load(path)); - -#if defined(USE_MPI) - bool use_nccl = parameters.allreduce_post_accumulation; - if (!use_nccl && parameters.world_size > 1) - CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger()); -#endif - const auto config_result = ConfigureSessionForTraining(static_cast(sess->GetSessionHandle()), parameters); - - ProviderOptionsVector merged_options; - ResolveExtraProviderOptions(provider_types, provider_options, merged_options); - - InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options); - - return config_result; - }) - .def("read_bytes", [ep_registration_fn](PyTrainingSession* sess, const py::bytes& serialized_model, TrainingParameters& parameters, const std::vector& provider_types, const ProviderOptionsVector& provider_options) { - 
std::istringstream buffer(serialized_model); - OrtPybindThrowIfError(sess->GetSessionHandle()->Load(buffer)); - -#if defined(USE_MPI) - bool use_nccl = parameters.allreduce_post_accumulation; - if (!use_nccl && parameters.world_size > 1) - CopyMPIContextToTrainingParameters(parameters, sess->GetSessionHandle()->GetLogger()); -#endif - const auto config_result = ConfigureSessionForTraining(static_cast(sess->GetSessionHandle()), parameters); - ProviderOptionsVector merged_options; - ResolveExtraProviderOptions(provider_types, provider_options, merged_options); - - InitializeSession(sess->GetSessionHandle(), ep_registration_fn, provider_types, merged_options); - - return config_result; - }) - .def("get_state", [](PyTrainingSession* sess) { - NameMLValMap state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetStateTensors(state_tensors)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - // convert to numpy array - std::map rmap; - for (auto& kv : state_tensors) { - if (kv.second.IsTensor()) { - py::object obj; - const Tensor& rtensor = kv.second.Get(); - GetPyObjFromTensor(rtensor, obj, &data_transfer_manager); - rmap.insert({kv.first, obj}); - } else { - throw std::runtime_error("Non tensor type in session state tensors is not expected."); - } - } - return rmap; - }) - .def("get_model_state", [](PyTrainingSession* sess, bool include_mixed_precision_weights) { - std::unordered_map model_state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetModelState(model_state_tensors, include_mixed_precision_weights)); - auto& data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - return ConvertORTTensorMapToNumpy(model_state_tensors, data_transfer_manager); - }) - .def("get_optimizer_state", [](PyTrainingSession* sess) { - std::unordered_map opt_state_tensors; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetOptimizerState(opt_state_tensors)); - auto& 
data_transfer_manager = sess->GetSessionHandle()->GetDataTransferManager(); - return ConvertORTTensorMapToNumpy(opt_state_tensors, data_transfer_manager); - }) - .def("get_partition_info_map", [](PyTrainingSession* sess) { - std::unordered_map>> part_info_map; - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->GetPartitionInfoMap(part_info_map)); - return part_info_map; - }) - .def("load_state", [](PyTrainingSession* sess, std::unordered_map& state, bool strict) { - NameMLValMap state_tensors; - for (auto initializer : state) { - OrtValue ml_value; - auto px = sess->GetSessionHandle()->GetModelInputs(); - if (!px.first.IsOK() || !px.second) { - throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null"); - } - CreateGenericMLValue(px.second, GetAllocator(), initializer.first, initializer.second, &ml_value); - ThrowIfPyErrOccured(); - state_tensors.insert(std::make_pair(initializer.first, ml_value)); - } - ORT_THROW_IF_ERROR(static_cast(sess->GetSessionHandle())->SetStateTensors(state_tensors, strict)); - }) - .def("is_output_fp32_node", [](PyTrainingSession* sess, const std::string& output_name) { - return static_cast(sess->GetSessionHandle())->IsGraphOutputFp32Node(output_name); - }); - py::class_(m, "PartialGraphExecutionState") .def(py::init([]() { return std::make_unique(); diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 88ef90a7feaa8..4d1db7334f280 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -40,7 +40,7 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* void addGlobalMethods(py::module& m); void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn); -void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn 
ep_registration_fn); +void addObjectMethodsForTraining(py::module& m); void addObjectMethodsForEager(py::module& m); #ifdef ENABLE_LAZY_TENSOR void addObjectMethodsForLazyTensor(py::module& m); @@ -339,7 +339,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { } #endif - addObjectMethodsForTraining(m, ORTTrainingRegisterExecutionProviders); + addObjectMethodsForTraining(m); #ifdef ENABLE_LAZY_TENSOR addObjectMethodsForLazyTensor(m); diff --git a/orttraining/orttraining/python/training/__init__.py b/orttraining/orttraining/python/training/__init__.py index 73b1f826f68e1..a3c22686a1039 100644 --- a/orttraining/orttraining/python/training/__init__.py +++ b/orttraining/orttraining/python/training/__init__.py @@ -8,26 +8,16 @@ TrainingParameters, is_ortmodule_available, ) -from onnxruntime.capi.training.training_session import TrainingSession - # Options need to be imported before `ORTTrainer`. -from .orttrainer_options import ORTTrainerOptions -from .orttrainer import ORTTrainer, TrainStepInfo -from . import amp, artifacts, checkpoint, model_desc_validation, optim +from . import amp, artifacts, optim __all__ = [ "PropagateCastOpsStrategy", "TrainingParameters", "is_ortmodule_available", - "TrainingSession", - "ORTTrainerOptions", - "ORTTrainer", - "TrainStepInfo", "amp", "artifacts", - "checkpoint", - "model_desc_validation", "optim", ] diff --git a/orttraining/orttraining/python/training/_checkpoint_storage.py b/orttraining/orttraining/python/training/_checkpoint_storage.py deleted file mode 100644 index 7a8ada7dee96b..0000000000000 --- a/orttraining/orttraining/python/training/_checkpoint_storage.py +++ /dev/null @@ -1,107 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- - -import pickle -from collections.abc import Mapping - -import h5py - - -def _dfs_save(group, save_obj): - """Recursively go over each level in the save_obj dictionary and save values to a hdf5 group""" - - for key, value in save_obj.items(): - if isinstance(value, Mapping): - subgroup = group.create_group(key) - _dfs_save(subgroup, value) - else: - group[key] = value - - -def save(save_obj: dict, path): - """Persists the input dictionary to a file specified by path. - - Saves an hdf5 representation of the save_obj dictionary to a file or a file-like object specified by path. - Values are saved in a format supported by h5py. For example, a PyTorch tensor is saved and loaded as a - numpy object. So, user types may be converted from their original types to numpy equivalent types. - - Args: - save_obj: dictionary that needs to be saved. - save_obj should consist of types supported by hdf5 file format. - if hdf5 does not recognize a type, an exception is raised. - if save_obj is not a dictionary, a ValueError is raised. - path: string representation to a file path or a python file-like object. - if file already exists at path, an exception is raised. - """ - if not isinstance(save_obj, Mapping): - raise ValueError("Object to be saved must be a dictionary") - - with h5py.File(path, "w-") as f: - _dfs_save(f, save_obj) - - -def _dfs_load(group, load_obj): - """Recursively go over each level in the hdf5 group and load the values into the given dictionary""" - - for key in group: - if isinstance(group[key], h5py.Group): - load_obj[key] = {} - _dfs_load(group[key], load_obj[key]) - else: - load_obj[key] = group[key][()] - - -def load(path, key=None): - """Loads the data stored in the binary file specified at the given path into a dictionary and returns it. - - Loads the data from an hdf5 file specified at the given path into a python dictionary. 
- Loaded dictionary contains numpy equivalents of python data types. For example: - PyTorch tensor -> saved as a numpy array and loaded as a numpy array. - bool -> saved as a numpy bool and loaded as a numpy bool - If a '/' separated key is provided, the value at that hierarchical level in the hdf5 group is returned. - - Args: - path: string representation to a file path or a python file-like object. - if file does not already exist at path, an exception is raised. - key: '/' separated representation of the hierarchy level value that needs to be returned/ - for example, if the saved binary file has structure {a: {b: x, c:y}} and the user would like - to query the value for c, the key provided should be 'a/c'. - the default value of None for key implies that the entire hdf5 file structure needs to be loaded into a dictionary and returned. - - Returns: - a dictionary loaded from the specified binary hdf5 file. - """ - if not h5py.is_hdf5(path): - raise ValueError(f"{path} is not an hdf5 file or a python file-like object.") - - load_obj = {} - with h5py.File(path, "r") as f: - if key: - f = f[key] # noqa: PLW2901 - if isinstance(f, h5py.Dataset): - return f[()] - - _dfs_load(f, load_obj) - - return load_obj - - -def to_serialized_hex(user_dict): - """Serialize the user_dict and convert the serialized bytes to a hex string and return""" - - return pickle.dumps(user_dict).hex() - - -def from_serialized_hex(serialized_hex): - """Convert serialized_hex to bytes and deserialize it and return""" - - # serialized_hex can be either a regular string or a byte string. 
- # if it is a byte string, convert to regular string using decode() - # if it is a regular string, do nothing to it - try: # noqa: SIM105 - serialized_hex = serialized_hex.decode() - except AttributeError: - pass - return pickle.loads(bytes.fromhex(serialized_hex)) diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py index 4eb79443c8f1a..091274d1d171d 100644 --- a/orttraining/orttraining/python/training/_utils.py +++ b/orttraining/orttraining/python/training/_utils.py @@ -6,11 +6,9 @@ import importlib.util import os import sys -from functools import wraps # noqa: F401 import numpy as np import torch -from onnx import TensorProto # noqa: F401 from packaging.version import Version @@ -23,16 +21,6 @@ def get_device_index(device): return 0 if device.index is None else device.index -def get_device_index_from_input(input): - """Returns device index from a input PyTorch Tensor""" - - if isinstance(input, (list, tuple)): - device_index = get_device_index(input[0].device) - else: - device_index = get_device_index(input.device) - return device_index - - def get_device_str(device): if isinstance(device, str): # could be 'cuda:0', 'cuda:1', or 'cpu'. 
with cpu, set index=0 @@ -50,24 +38,6 @@ def get_device_str(device): return device -def get_all_gradients_finite_name_from_session(session): - """Find all_gradients_finite node on Session graph and return its name""" - - nodes = [x for x in session._outputs_meta if "all_gradients_finite" in x.name] - if len(nodes) != 1: - raise RuntimeError("'all_gradients_finite' node not found within training session") - return nodes[0].name - - -def get_gradient_accumulation_name_from_session(session): - """Find Group_Accumulated_Gradients node on Session graph and return its name""" - - nodes = [x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name] - if len(nodes) != 1: - raise RuntimeError("'Group_Accumulated_Gradients' node not found within training session") - return nodes[0].name - - def dtype_torch_to_numpy(torch_dtype): """Converts PyTorch types to Numpy types @@ -232,111 +202,3 @@ def import_module_from_file(file_path, module_name=None): sys.modules[module_name] = module spec.loader.exec_module(module) return module - - -def state_dict_model_key(): - """Returns the model key name in the state dictionary""" - - return "model" - - -def state_dict_optimizer_key(): - """Returns the optimizer key name in the state dictionary""" - - return "optimizer" - - -def state_dict_partition_info_key(): - """Returns the partition info key name in the state dictionary""" - - return "partition_info" - - -def state_dict_trainer_options_key(): - """Returns the trainer options key name in the state dictionary""" - - return "trainer_options" - - -def state_dict_full_precision_key(): - """Returns the full precision key name in the state dictionary""" - - return "full_precision" - - -def state_dict_original_dimension_key(): - """Returns the original dimension key name in the state dictionary""" - - return "original_dim" - - -def state_dict_sharded_optimizer_keys(): - """Returns the optimizer key names that can be sharded in the state dictionary""" - - return {"Moment_1", 
"Moment_2"} - - -def state_dict_user_dict_key(): - """Returns the user dict key name in the state dictionary""" - - return "user_dict" - - -def state_dict_trainer_options_mixed_precision_key(): - """Returns the trainer options mixed precision key name in the state dictionary""" - - return "mixed_precision" - - -def state_dict_trainer_options_zero_stage_key(): - """Returns the trainer options zero_stage key name in the state dictionary""" - - return "zero_stage" - - -def state_dict_trainer_options_world_rank_key(): - """Returns the trainer options world_rank key name in the state dictionary""" - - return "world_rank" - - -def state_dict_trainer_options_world_size_key(): - """Returns the trainer options world_size key name in the state dictionary""" - - return "world_size" - - -def state_dict_trainer_options_data_parallel_size_key(): - """Returns the trainer options data_parallel_size key name in the state dictionary""" - - return "data_parallel_size" - - -def state_dict_trainer_options_horizontal_parallel_size_key(): - """Returns the trainer options horizontal_parallel_size key name in the state dictionary""" - - return "horizontal_parallel_size" - - -def state_dict_trainer_options_optimizer_name_key(): - """Returns the trainer options optimizer_name key name in the state dictionary""" - - return "optimizer_name" - - -def state_dict_train_step_info_key(): - """Returns the train step info key name in the state dictionary""" - - return "train_step_info" - - -def state_dict_train_step_info_optimization_step_key(): - """Returns the train step info optimization step key name in the state dictionary""" - - return "optimization_step" - - -def state_dict_train_step_info_step_key(): - """Returns the train step info step key name in the state dictionary""" - - return "step" diff --git a/orttraining/orttraining/python/training/checkpoint.py b/orttraining/orttraining/python/training/checkpoint.py deleted file mode 100644 index d0ff0650662b7..0000000000000 --- 
a/orttraining/orttraining/python/training/checkpoint.py +++ /dev/null @@ -1,748 +0,0 @@ -import os -import tempfile -import warnings -from enum import Enum - -import numpy as np -import onnx -import torch - -from . import _checkpoint_storage, _utils - -################################################################################ -# Experimental Checkpoint APIs -################################################################################ - - -def experimental_state_dict(ort_trainer, include_optimizer_state=True): - warnings.warn( - "experimental_state_dict() will be deprecated soon. Please use ORTTrainer.state_dict() instead.", - DeprecationWarning, - ) - - if not ort_trainer._training_session: - warnings.warn( - "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling state_dict()." - ) - return ort_trainer._state_dict - - # extract trained weights - session_state = ort_trainer._training_session.get_state() - torch_state = {} - for name in session_state: - torch_state[name] = torch.from_numpy(session_state[name]) - - # extract untrained weights and buffer - for n in ort_trainer._onnx_model.graph.initializer: - if n.name not in torch_state and n.name in ort_trainer.options.utils.frozen_weights: - torch_state[n.name] = torch.from_numpy(np.array(onnx.numpy_helper.to_array(n))) - - # Need to remove redundant (optimizer) initializers to map back to original torch state names - if not include_optimizer_state and ort_trainer._torch_state_dict_keys: - return {key: torch_state[key] for key in ort_trainer._torch_state_dict_keys if key in torch_state} - return torch_state - - -def experimental_load_state_dict(ort_trainer, state_dict, strict=False): - warnings.warn( - "experimental_load_state_dict() will be deprecated soon. 
Please use ORTTrainer.load_state_dict() instead.", - DeprecationWarning, - ) - - # Note: It may happen ONNX model has not yet been initialized - # In this case we cache a reference to desired state and delay the restore until after initialization - # Unexpected behavior will result if the user changes the reference before initialization - if not ort_trainer._training_session: - ort_trainer._state_dict = state_dict - ort_trainer._load_state_dict_strict = strict - return - - # Update onnx model from loaded state dict - cur_initializers_names = [n.name for n in ort_trainer._onnx_model.graph.initializer] - new_initializers = {} - - for name in state_dict: - if name in cur_initializers_names: - new_initializers[name] = state_dict[name].numpy() - elif strict: - raise RuntimeError(f"Checkpoint tensor: {name} is not present in the model.") - - ort_trainer._update_onnx_model_initializers(new_initializers) - - # create new session based on updated onnx model - ort_trainer._state_dict = None - ort_trainer._init_session() - - # load training state - session_state = {name: state_dict[name].numpy() for name in state_dict} - ort_trainer._training_session.load_state(session_state, strict) - - -def experimental_save_checkpoint( - ort_trainer, - checkpoint_dir, - checkpoint_prefix="ORT_checkpoint", - checkpoint_state_dict=None, - include_optimizer_state=True, -): - warnings.warn( - "experimental_save_checkpoint() will be deprecated soon. 
Please use ORTTrainer.save_checkpoint() instead.", - DeprecationWarning, - ) - - if checkpoint_state_dict is None: - checkpoint_state_dict = {"model": experimental_state_dict(ort_trainer, include_optimizer_state)} - else: - checkpoint_state_dict.update({"model": experimental_state_dict(ort_trainer, include_optimizer_state)}) - - assert os.path.exists(checkpoint_dir), f"checkpoint_dir ({checkpoint_dir}) directory doesn't exist" - - checkpoint_name = _get_checkpoint_name( - checkpoint_prefix, - ort_trainer.options.distributed.deepspeed_zero_optimization.stage, - ort_trainer.options.distributed.world_rank, - ort_trainer.options.distributed.world_size, - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - if os.path.exists(checkpoint_file): - msg = f"{checkpoint_file} already exists, overwriting." - warnings.warn(msg) - torch.save(checkpoint_state_dict, checkpoint_file) - - -def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): - warnings.warn( - "experimental_load_checkpoint() will be deprecated soon. Please use ORTTrainer.load_checkpoint() instead.", - DeprecationWarning, - ) - - checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - is_partitioned = False - if len(checkpoint_files) > 1: - msg = ( - f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." - " Attempting to load ZeRO checkpoint." 
- ) - warnings.warn(msg) - is_partitioned = True - if (not ort_trainer.options.distributed.deepspeed_zero_optimization.stage) and is_partitioned: - return _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict) - else: - return _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) - - -class _AGGREGATION_MODE(Enum): # noqa: N801 - Zero = 0 - Megatron = 1 - - -def _order_paths(paths, D_groups, H_groups): - """Reorders the given paths in order of aggregation of ranks for D and H parallellism respectively - and returns the ordered dict""" - - trainer_options_path_tuples = [] - world_rank = _utils.state_dict_trainer_options_world_rank_key() - - for path in paths: - trainer_options_path_tuples.append( - (_checkpoint_storage.load(path, key=_utils.state_dict_trainer_options_key()), path) - ) - - # sort paths according to rank - sorted_paths = [ - path - for _, path in sorted( - trainer_options_path_tuples, key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank] - ) - ] - - ordered_paths = dict() - ordered_paths["D"] = [[sorted_paths[i] for i in D_groups[group_id]] for group_id in range(len(D_groups))] - ordered_paths["H"] = [[sorted_paths[i] for i in H_groups[group_id]] for group_id in range(len(H_groups))] - - return ordered_paths - - -def _add_or_update_sharded_key( - state_key, state_value, state_sub_dict, model_state_key, state_partition_info, sharded_states_original_dims, mode -): - """Add or update the record for the sharded state_key in the state_sub_dict""" - - # record the original dimension for this state - original_dim = _utils.state_dict_original_dimension_key() - sharded_states_original_dims[model_state_key] = state_partition_info[original_dim] - - axis = 0 - if mode == _AGGREGATION_MODE.Megatron and state_partition_info["megatron_row_partition"] == 0: - axis = -1 - - if state_key in state_sub_dict: - # state_dict already contains a record for this state - # since this 
state is sharded, concatenate the state value to - # the record in the state_dict - state_sub_dict[state_key] = np.concatenate((state_sub_dict[state_key], state_value), axis) - else: - # create a new entry for this state in the state_dict - state_sub_dict[state_key] = state_value - - -def _add_or_validate_unsharded_key(state_key, state_value, state_sub_dict, mismatch_error_string): - """Add or validate the record for the unsharded state_key in the state_sub_dict""" - - if state_key in state_sub_dict: - # state_dict already contains a record for this unsharded state. - # assert that all values are the same for this previously loaded state - assert (state_sub_dict[state_key] == state_value).all(), mismatch_error_string - else: - # create a new entry for this state in the state_sub_dict - state_sub_dict[state_key] = state_value - - -def _aggregate_model_states( - rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled, mode=_AGGREGATION_MODE.Zero -): - """Aggregates all model states from the rank_state_dict into state_dict""" - - model = _utils.state_dict_model_key() - full_precision = _utils.state_dict_full_precision_key() - partition_info = _utils.state_dict_partition_info_key() - - # if there are no model states in the rank_state_dict, no model aggregation is needed - if model not in rank_state_dict: - return - - if model not in state_dict: - state_dict[model] = {} - - if full_precision not in state_dict[model]: - state_dict[model][full_precision] = {} - - # iterate over all model state keys - for model_state_key, model_state_value in rank_state_dict[model][full_precision].items(): - # ZERO: full precision model states are sharded only when they exist in the partition_info subdict and mixed - # precision training was enabled. 
for full precision training, full precision model states are not sharded - # MEGATRON : full precision model states are sharded when they exist in the partition_info subdict - if (model_state_key in rank_state_dict[partition_info]) and ( - mode == _AGGREGATION_MODE.Megatron or mixed_precision_enabled - ): - # this model state is sharded - _add_or_update_sharded_key( - model_state_key, - model_state_value, - state_dict[model][full_precision], - model_state_key, - rank_state_dict[partition_info][model_state_key], - sharded_states_original_dims, - mode, - ) - else: - # this model state is not sharded since a record for it does not exist in the partition_info subdict - _add_or_validate_unsharded_key( - model_state_key, - model_state_value, - state_dict[model][full_precision], - f"Value mismatch for model state {model_state_key}", - ) - - -def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode=_AGGREGATION_MODE.Zero): - """Aggregates all optimizer states from the rank_state_dict into state_dict""" - - optimizer = _utils.state_dict_optimizer_key() - partition_info = _utils.state_dict_partition_info_key() - sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() - - # if there are no optimizer states in the rank_state_dict, no optimizer aggregation is needed - if optimizer not in rank_state_dict: - return - - if optimizer not in state_dict: - state_dict[optimizer] = {} - - # iterate over all optimizer state keys - for model_state_key, optimizer_dict in rank_state_dict[optimizer].items(): - for optimizer_key, optimizer_value in optimizer_dict.items(): - if model_state_key not in state_dict[optimizer]: - state_dict[optimizer][model_state_key] = {} - - if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]: - # this optimizer state is sharded since a record exists in the partition_info subdict - _add_or_update_sharded_key( - optimizer_key, - optimizer_value, - 
state_dict[optimizer][model_state_key], - model_state_key, - rank_state_dict[partition_info][model_state_key], - sharded_states_original_dims, - mode, - ) - else: - # this optimizer state is not sharded since a record for it does not exist in the partition_info subdict - # or this optimizer key is not one of the sharded optimizer keys - _add_or_validate_unsharded_key( - optimizer_key, - optimizer_value, - state_dict[optimizer][model_state_key], - f"Value mismatch for model state {model_state_key} and optimizer state {optimizer_key}", - ) - - -def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled): - """Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims""" - - model = _utils.state_dict_model_key() - full_precision = _utils.state_dict_full_precision_key() - optimizer = _utils.state_dict_optimizer_key() - sharded_optimizer_keys = _utils.state_dict_sharded_optimizer_keys() - - for sharded_state_key, original_dim in sharded_states_original_dims.items(): - # reshape model states to original_dim only when mixed precision is enabled - if mixed_precision_enabled and (model in state_dict): - state_dict[model][full_precision][sharded_state_key] = state_dict[model][full_precision][ - sharded_state_key - ].reshape(original_dim) - - # reshape optimizer states to original_dim - if optimizer in state_dict: - for optimizer_key, optimizer_value in state_dict[optimizer][sharded_state_key].items(): - if optimizer_key in sharded_optimizer_keys: - state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim) - - -def _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation): - """Extracts trainer options from rank_state_dict and loads them accordingly on state_dict""" - trainer_options = _utils.state_dict_trainer_options_key() - state_dict[trainer_options] = {} - - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - 
zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - - state_dict[trainer_options][mixed_precision] = rank_state_dict[trainer_options][mixed_precision] - state_dict[trainer_options][zero_stage] = 0 - state_dict[trainer_options][world_rank] = rank_state_dict[trainer_options][world_rank] if partial_aggregation else 0 - state_dict[trainer_options][world_size] = 1 - state_dict[trainer_options][optimizer_name] = rank_state_dict[trainer_options][optimizer_name] - state_dict[trainer_options][D_size] = 1 - state_dict[trainer_options][H_size] = 1 - - -def _aggregate_megatron_partition_info(rank_state_dict, state_dict): - """Extracts partition_info from rank_state_dict and loads on state_dict for megatron-partitioned weights""" - partition_info = _utils.state_dict_partition_info_key() - if partition_info not in state_dict: - state_dict[partition_info] = {} - - rank_partition_info = rank_state_dict[partition_info] - for model_state_key, partition_info_dict in rank_partition_info.items(): - if model_state_key not in state_dict[partition_info]: - # add partition info only if weight is megatron partitioned - if partition_info_dict["megatron_row_partition"] >= 0: - state_dict[partition_info][model_state_key] = partition_info_dict - - -def _to_pytorch_format(state_dict): - """Convert ORT state dictionary schema (hierarchical structure) to PyTorch state dictionary schema (flat structure)""" - - pytorch_state_dict = {} - for model_state_key, model_state_value in state_dict[_utils.state_dict_model_key()][ - _utils.state_dict_full_precision_key() - ].items(): - # convert numpy array to a torch tensor 
- pytorch_state_dict[model_state_key] = torch.tensor(model_state_value) - return pytorch_state_dict - - -def _get_parallellism_groups(data_parallel_size, horizontal_parallel_size, world_size): - """Returns the D and H groups for the given sizes""" - num_data_groups = world_size // data_parallel_size - data_groups = [] - for data_group_id in range(num_data_groups): - data_group_ranks = [] - for r in range(data_parallel_size): - data_group_ranks.append(data_group_id + horizontal_parallel_size * r) - data_groups.append(data_group_ranks) - - num_horizontal_groups = world_size // horizontal_parallel_size - horizontal_groups = [] - for hori_group_id in range(num_horizontal_groups): - hori_group_ranks = [] - for r in range(horizontal_parallel_size): - hori_group_ranks.append(hori_group_id * horizontal_parallel_size + r) - horizontal_groups.append(hori_group_ranks) - - return data_groups, horizontal_groups - - -def _aggregate_over_ranks( - ordered_paths, - ranks, - sharded_states_original_dims=None, - mode=_AGGREGATION_MODE.Zero, - partial_aggregation=False, - pytorch_format=True, -): - """Aggregate checkpoint files over set of ranks and return a single state dictionary - - Args: - ordered_paths: list of paths in the order in which they must be aggregated - ranks: list of ranks that are to be aggregated - sharded_states_original_dims: dict containing the original dims for sharded states that are persisted over - multiple calls to _aggregate_over_ranks() - mode: mode of aggregation: Zero or Megatron - partial_aggregation: boolean flag to indicate whether to produce a partially - aggregated state which can be further aggregated over - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict - Returns: - state_dict that can be loaded into an ORTTrainer or into a PyTorch model - """ - state_dict = {} - if sharded_states_original_dims is None: - sharded_states_original_dims = dict() - world_rank = 
_utils.state_dict_trainer_options_world_rank_key() - mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - - loaded_mixed_precision = None - loaded_world_size = None - loaded_zero_stage = None - loaded_optimizer_name = None - - for i, path in enumerate(ordered_paths): - rank_state_dict = _checkpoint_storage.load(path) - - assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info" - assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options" - assert ( - ranks[i] == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] - ), "Unexpected rank in file at path {}. Expected {}, got {}".format( - path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] # noqa: F821 - ) - if loaded_mixed_precision is None: - loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] - else: - assert ( - loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] - ), f"Mixed precision state mismatch among checkpoint files. File: {path}" - if loaded_world_size is None: - loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] - else: - assert ( - loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] - ), f"World size state mismatch among checkpoint files. File: {path}" - if loaded_zero_stage is None: - loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] - else: - assert ( - loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] - ), f"Zero stage mismatch among checkpoint files. 
File: {path}" - if loaded_optimizer_name is None: - loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] - else: - assert ( - loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] - ), f"Optimizer name mismatch among checkpoint files. File: {path}" - - # aggregate all model states - _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision, mode) - - if not pytorch_format: - # aggregate all optimizer states if pytorch_format is False - _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode) - - # for D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups - # to aggregate over Zero, and another pass to aggregate Megatron partitioned - # states. Preserve the relevant partition info only for weights that are megatron partitioned for - # a partial aggregation call - if partial_aggregation: - _aggregate_megatron_partition_info(rank_state_dict, state_dict) - - # entry for trainer_options in the state_dict to perform other sanity checks - if _utils.state_dict_trainer_options_key() not in state_dict: - _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation) - - # entry for user_dict in the state_dict if not already present - if ( - _utils.state_dict_user_dict_key() not in state_dict - and _utils.state_dict_user_dict_key() in rank_state_dict - ): - state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()] - - # for a partial aggregation scenario, we might not have the entire tensor aggregated yet, thus skip reshape - if not partial_aggregation: - # reshape all the sharded tensors based on the original dimensions stored in sharded_states_original_dims - _reshape_states(sharded_states_original_dims, state_dict, loaded_mixed_precision) - - # return a flat structure for PyTorch model in case pytorch_format is 
True - # else return the hierarchical structure for ORTTrainer - return _to_pytorch_format(state_dict) if pytorch_format else state_dict - - -def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format): # noqa: N802 - """Aggregate checkpoint files and return a single state dictionary for the D+H - (Zero+Megatron) partitioning strategy. - For D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups - to aggregate over Zero, and another pass over the previously aggregated states - to aggregate Megatron partitioned states. - """ - sharded_states_original_dims = {} - aggregate_data_checkpoint_files = [] - - # combine for Zero over data groups and save to temp file - with tempfile.TemporaryDirectory() as save_dir: - for group_id, d_group in enumerate(D_groups): - aggregate_state_dict = _aggregate_over_ranks( - ordered_paths["D"][group_id], - d_group, - sharded_states_original_dims, - partial_aggregation=True, - pytorch_format=False, - ) - - filename = "ort.data_group." + str(group_id) + ".ort.pt" - filepath = os.path.join(save_dir, filename) - _checkpoint_storage.save(aggregate_state_dict, filepath) - aggregate_data_checkpoint_files.append(filepath) - - assert len(aggregate_data_checkpoint_files) > 0 - - # combine for megatron: - aggregate_state = _aggregate_over_ranks( - aggregate_data_checkpoint_files, - H_groups[0], - sharded_states_original_dims, - mode=_AGGREGATION_MODE.Megatron, - pytorch_format=pytorch_format, - ) - - return aggregate_state - - -def aggregate_checkpoints(paths, pytorch_format=True): - """Aggregate checkpoint files and return a single state dictionary - - Aggregates checkpoint files specified by paths and loads them one at a time, merging - them into a single state dictionary. - The checkpoint files represented by paths must be saved through ORTTrainer.save_checkpoint() function. 
- The schema of the state_dict returned will be in the same as the one returned by ORTTrainer.state_dict() - - Args: - paths: list of more than one file represented as strings where the checkpoint is saved - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict - Returns: - state_dict that can be loaded into an ORTTrainer or into a PyTorch model - """ - - loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key()) - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - world_size = _utils.state_dict_trainer_options_world_size_key() - - D_size = loaded_trainer_options[D_size] # noqa: N806 - H_size = loaded_trainer_options[H_size] # noqa: N806 - world_size = loaded_trainer_options[world_size] - D_groups, H_groups = _get_parallellism_groups(D_size, H_size, world_size) # noqa: N806 - - combine_zero = loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 - combine_megatron = len(H_groups[0]) > 1 - - # order the paths in the order of groups in which they must be aggregated according to - # data-parallel groups and H-parallel groups obtained - # eg: {'D': [[path_0, path_2],[path_1, path_3]], 'H': [[path_0, path_1],[path_2, path_3]]} - ordered_paths = _order_paths(paths, D_groups, H_groups) - - aggregate_state = None - if combine_zero and combine_megatron: - aggregate_state = _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format) - elif combine_zero: - aggregate_state = _aggregate_over_ranks( - ordered_paths["D"][0], D_groups[0], mode=_AGGREGATION_MODE.Zero, pytorch_format=pytorch_format - ) - elif combine_megatron: - aggregate_state = _aggregate_over_ranks( - ordered_paths["H"][0], H_groups[0], mode=_AGGREGATION_MODE.Megatron, pytorch_format=pytorch_format - ) - - return aggregate_state - - 
-################################################################################ -# Helper functions -################################################################################ - - -def _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict): - checkpoint_name = _get_checkpoint_name( - checkpoint_prefix, - is_partitioned, - ort_trainer.options.distributed.world_rank, - ort_trainer.options.distributed.world_size, - ) - checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) - - if is_partitioned: - assert_msg = ( - f"Couldn't find checkpoint file {checkpoint_file}." - " Optimizer partitioning is enabled using ZeRO. Please make sure the checkpoint file exists " - f"for rank {ort_trainer.options.distributed.world_rank} of {ort_trainer.options.distributed.world_size}" - ) - else: - assert_msg = f"Couldn't find checkpoint file {checkpoint_file}." - assert os.path.exists(checkpoint_file), assert_msg - - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - experimental_load_state_dict(ort_trainer, checkpoint_state["model"], strict=strict) - del checkpoint_state["model"] - return checkpoint_state - - -def _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, strict): - checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix) - - ckpt_agg = _CombineZeroCheckpoint(checkpoint_files) - aggregate_state_dict = ckpt_agg.aggregate_checkpoints() - - experimental_load_state_dict(ort_trainer, aggregate_state_dict, strict=strict) - - # aggregate other keys in the state_dict. 
- # Values will be overwritten for matching keys among workers - all_checkpoint_states = dict() - for checkpoint_file in checkpoint_files: - checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - del checkpoint_state["model"] - all_checkpoint_states.update(checkpoint_state) - return all_checkpoint_states - - -def _list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"): - ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)] - ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)] - ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names] - - assert len(ckpt_file_names) > 0, f"No checkpoint found with prefix '{checkpoint_prefix}' at '{checkpoint_dir}'" - return ckpt_file_names - - -def _get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None): - SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" # noqa: N806 - MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" # noqa: N806 - - if is_partitioned: - filename = MULTIPLE_CHECKPOINT_FILENAME.format( - prefix=prefix, world_rank=world_rank, world_size=(world_size - 1) - ) - else: - filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix) - return filename - - -def _split_state_dict(state_dict): - optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] - split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} - for k, v in state_dict.items(): - mode = "fp32_param" - for optim_key in optimizer_keys: - if k.startswith(optim_key): - mode = "optimizer" - break - if k.endswith("_fp16"): - mode = "fp16_param" - split_sd[mode][k] = v - return split_sd - - -class _CombineZeroCheckpoint: - def __init__(self, checkpoint_files, clean_state_dict=None): - assert len(checkpoint_files) > 0, "No checkpoint files passed" - self.checkpoint_files = checkpoint_files - self.clean_state_dict = clean_state_dict - self.world_size = 
int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1 - assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files" - self.weight_shape_map = {} - self.sharded_params = set() - - def _split_name(self, name: str): - name_split = name.split("_view_") - view_num = None - if len(name_split) > 1: - view_num = int(name_split[1]) - optimizer_key = "" - mp_suffix = "" - if name_split[0].startswith("Moment_1"): - optimizer_key = "Moment_1_" - elif name_split[0].startswith("Moment_2"): - optimizer_key = "Moment_2_" - elif name_split[0].startswith("Update_Count"): - optimizer_key = "Update_Count_" - elif name_split[0].endswith("_fp16"): - mp_suffix = "_fp16" - param_name = name_split[0] - if optimizer_key: - param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split("_fp16")[0] - return param_name, optimizer_key, view_num, mp_suffix - - def _update_weight_statistics(self, name, value): - if name not in self.weight_shape_map: - self.weight_shape_map[name] = value.size() # original shape of tensor - - def _reshape_tensor(self, key): - value = self.aggregate_state_dict[key] - weight_name, _, _, _ = self._split_name(key) - set_size = self.weight_shape_map[weight_name] - self.aggregate_state_dict[key] = value.reshape(set_size) - - def _aggregate(self, param_dict): - for k, v in param_dict.items(): - weight_name, optimizer_key, view_num, mp_suffix = self._split_name(k) - if view_num is not None: - # parameter is sharded - param_name = optimizer_key + weight_name + mp_suffix - - if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]: - self.sharded_params.add(param_name) - # Found a previous shard of the param, concatenate shards ordered by ranks - self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v)) - else: - self.aggregate_state_dict[param_name] = v - else: - if k in self.aggregate_state_dict: - assert (self.aggregate_state_dict[k] == 
v).all(), "Unsharded params must have the same value" - else: - self.aggregate_state_dict[k] = v - self._update_weight_statistics(weight_name, v) - - def aggregate_checkpoints(self): - warnings.warn( - "_CombineZeroCheckpoint.aggregate_checkpoints() will be deprecated soon. " - "Please use aggregate_checkpoints() instead.", - DeprecationWarning, - ) - - checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0] - self.aggregate_state_dict = dict() - - for i in range(self.world_size): - checkpoint_name = _get_checkpoint_name(checkpoint_prefix, True, i, self.world_size) - rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu")) - if "model" in rank_state_dict: - rank_state_dict = rank_state_dict["model"] - - if self.clean_state_dict: - rank_state_dict = self.clean_state_dict(rank_state_dict) - - rank_state_dict = _split_state_dict(rank_state_dict) - self._aggregate(rank_state_dict["fp16_param"]) - self._aggregate(rank_state_dict["fp32_param"]) - self._aggregate(rank_state_dict["optimizer"]) - - for k in self.sharded_params: - self._reshape_tensor(k) - return self.aggregate_state_dict diff --git a/orttraining/orttraining/python/training/model_desc_validation.py b/orttraining/orttraining/python/training/model_desc_validation.py deleted file mode 100644 index dd3f4cb95cd59..0000000000000 --- a/orttraining/orttraining/python/training/model_desc_validation.py +++ /dev/null @@ -1,408 +0,0 @@ -from collections import namedtuple - -import cerberus -import torch - -from ._utils import static_vars - -LEARNING_RATE_IO_DESCRIPTION_NAME = "__learning_rate" -ALL_FINITE_IO_DESCRIPTION_NAME = "__all_finite" -LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME = "__loss_scale_input_name" -GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME = "__gradient_accumulation_name" - - -class _ORTTrainerModelDesc: - def __init__(self, model_desc): - # Keep a copy of original input for debug - self._original = dict(model_desc) - - # Global counter used to validate occurrences of 
'is_loss=True' whithin 'model_desc.outputs' - # A stateless validator is used for each tuple, but validation accross the whole list of tuple is needed - # because just one 'is_loss=True' is allowed withing 'model_desc.outputs' list of tuples - _model_desc_outputs_validation.loss_counter = 0 - - # Used for logging purposes - self._main_class_name = self.__class__.__name__ - - # Validates user input - self._validated = dict(self._original) - validator = cerberus.Validator(MODEL_DESC_SCHEMA) - self._validated = validator.validated(self._validated) - if self._validated is None: - raise ValueError(f"Invalid model_desc: {validator.errors}") - - # Normalize inputs to a list of namedtuple(name, shape) - self._InputDescription = namedtuple("InputDescription", ["name", "shape"]) - self._InputDescriptionTyped = namedtuple("InputDescriptionTyped", ["name", "shape", "dtype"]) - for idx, input in enumerate(self._validated["inputs"]): - self._validated["inputs"][idx] = self._InputDescription(*input) - - # Normalize outputs to a list of namedtuple(name, shape, is_loss) - self._OutputDescription = namedtuple("OutputDescription", ["name", "shape", "is_loss"]) - self._OutputDescriptionTyped = namedtuple( - "OutputDescriptionTyped", ["name", "shape", "is_loss", "dtype", "dtype_amp"] - ) - for idx, output in enumerate(self._validated["outputs"]): - if len(output) == 2: - self._validated["outputs"][idx] = self._OutputDescription(*output, False) - else: - self._validated["outputs"][idx] = self._OutputDescription(*output) - - # Hard-code learning rate, all_finite descriptors - self.learning_rate = self._InputDescriptionTyped(LEARNING_RATE_IO_DESCRIPTION_NAME, [1], torch.float32) - - # Convert dict in object - for k, v in self._validated.items(): - setattr(self, k, self._wrap(v)) - - def __repr__(self): - """Pretty representation for a model description class""" - - pretty_msg = "Model description:\n" - - # Inputs - inputs = [] - for i_desc in self.inputs: - if isinstance(i_desc, 
self._InputDescription): - inputs.append(f"(name={i_desc.name}, shape={i_desc.shape})") - elif isinstance(i_desc, self._InputDescriptionTyped): - inputs.append(f"(name={i_desc.name}, shape={i_desc.shape}, dtype={i_desc.dtype})") - else: - raise ValueError(f"Unexpected type {type(i_desc)} for input description") - - pretty_msg += "\nInputs:" - for idx, item in enumerate(inputs): - pretty_msg += f"\n\t{idx}: {item}" - - # Outputs - outputs = [] - for o_desc in self.outputs: - if isinstance(o_desc, self._OutputDescription): - outputs.append(f"(name={o_desc.name}, shape={o_desc.shape})") - elif isinstance(o_desc, self._OutputDescriptionTyped): - outputs.append( - f"(name={o_desc.name}, shape={o_desc.shape}, dtype={o_desc.dtype}, dtype_amp={o_desc.dtype_amp})" - ) - else: - raise ValueError(f"Unexpected type {type(o_desc)} for output description") - pretty_msg += "\nOutputs:" - for idx, item in enumerate(outputs): - pretty_msg += f"\n\t{idx}: {item}" - - # Learning rate - if self.learning_rate: - pretty_msg += "\nLearning rate: " - pretty_msg += ( - f"(name={self.learning_rate.name}, shape={self.learning_rate.shape}, dtype={self.learning_rate.dtype})" - ) - - # Mixed precision - if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) or getattr( - self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None - ): - pretty_msg += "\nMixed Precision:" - if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None): - pretty_msg += "\n\tis gradients finite: " - pretty_msg += ( - f"(name={self.all_finite.name}, shape={self.all_finite.shape}, dtype={self.all_finite.dtype})" - ) - if getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None): - pretty_msg += "\n\tloss scale input name: " - pretty_msg += f"(name={self.loss_scale_input.name}, shape={self.loss_scale_input.shape}, dtype={self.loss_scale_input.dtype})" - - # Gradient Accumulation steps - if self.gradient_accumulation: - pretty_msg += "\nGradient Accumulation: " - pretty_msg += f"(name={self.gradient_accumulation.name}, 
shape={self.gradient_accumulation.shape}, dtype={self.gradient_accumulation.dtype})" - - return pretty_msg - - def add_type_to_input_description(self, index, dtype): - """Updates an existing input description at position 'index' with 'dtype' type information - - Args: - index (int): position within 'inputs' description - dtype (torch.dtype): input data type - """ - - assert isinstance(index, int) and index >= 0, "input 'index' must be a positive int" - assert isinstance(dtype, torch.dtype), "input 'dtype' must be a torch.dtype type" - existing_values = (*self.inputs[index],) - if isinstance(self.inputs[index], self._InputDescriptionTyped): - existing_values = (*existing_values[:-1],) - self.inputs[index] = self._InputDescriptionTyped(*existing_values, dtype) - - def add_type_to_output_description(self, index, dtype, dtype_amp=None): - """Updates an existing output description at position 'index' with 'dtype' type information - - Args: - index (int): position within 'inputs' description - dtype (torch.dtype): input data type - dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision - """ - - assert isinstance(index, int) and index >= 0, "output 'index' must be a positive int" - assert isinstance(dtype, torch.dtype), "output 'dtype' must be a torch.dtype type" - assert dtype_amp is None or isinstance( - dtype_amp, torch.dtype - ), "output 'dtype_amp' must be either None or torch.dtype type" - existing_values = (*self.outputs[index],) - if isinstance(self.outputs[index], self._OutputDescriptionTyped): - existing_values = (*existing_values[:-2],) - self.outputs[index] = self._OutputDescriptionTyped(*existing_values, dtype, dtype_amp) - - @property - def gradient_accumulation(self): - return getattr(self, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, None) - - @gradient_accumulation.setter - def gradient_accumulation(self, name): - self._add_output_description( - self, name, [1], False, torch.bool, None, 
GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - @property - def all_finite(self): - return getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) - - @all_finite.setter - def all_finite(self, name): - self._add_output_description( - self, name, [1], False, torch.bool, None, ALL_FINITE_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - @property - def loss_scale_input(self): - return getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None) - - @loss_scale_input.setter - def loss_scale_input(self, name): - self._add_input_description( - self, name, [], torch.float32, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, ignore_duplicate=True - ) - - def _add_input_description(self, node, name, shape, dtype=None, attr_name=None, ignore_duplicate=False): - """Add a new input description into the node object - - If 'dtype' is specified, a typed input description namedtuple(name, shape, dtype) is created. - Otherwise an untyped input description namedtuple(name, shape) is created instead. - - Args: - node (list or object): node to append input description to. When 'node' is 'self.inputs', - a new input description is appended to the list. 
- Otherwise, a new input description is created as an attribute into 'node' with name 'attr_name' - name (str): name of input description - shape (list): shape of input description - dtype (torch.dtype): input data type - attr_name (str, default is None): friendly name to allow direct access to the output description - ignore_duplicate (bool, default is False): silently skips addition of duplicate inputs - """ - - assert isinstance(name, str) and len(name) > 0, "'name' is an invalid input name" - not_found = True - if not ignore_duplicate: - if id(node) == id(self.inputs): - not_found = all([name not in i_desc.name for i_desc in node]) - assert not_found, f"'name' {name} already exists in the inputs description" - else: - not_found = attr_name not in dir(self) - assert not_found, f"'attr_name' {attr_name} already exists in the 'node'" - elif not not_found: - return - assert isinstance(shape, list) and all( - [(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape] - ), "'shape' must be a list of int or str with length at least 1" - assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type" - if dtype: - new_input_desc = self._InputDescriptionTyped(name, shape, dtype) - else: - new_input_desc = self._InputDescription(name, shape) - - if id(node) == id(self.inputs): - self.inputs.append(new_input_desc) - else: - assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'" - setattr(node, attr_name, new_input_desc) - - def _add_output_description( - self, node, name, shape, is_loss, dtype=None, dtype_amp=None, attr_name=None, ignore_duplicate=False - ): - """Add a new output description into the node object as a tuple - - When (name, shape, is_loss, dtype) is specified, a typed output description is created - Otherwise an untyped output description (name, shape, is_loss) is created instead - - Args: - node (list or object): node to append output description to. 
When 'node' is 'self.outputs', - a new output description is appended to the list. - Otherwise, a new output description is created as an attribute into 'node' with name 'attr_name' - name (str): name of output description - shape (list): shape of output description - is_loss (bool): specifies whether this output is a loss - dtype (torch.dtype): input data type - dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision. - attr_name (str, default is None): friendly name to allow direct access to the output description - ignore_duplicate (bool, default is False): silently skips addition of duplicate outputs - """ - - assert isinstance(name, str) and len(name) > 0, "'name' is an invalid output name" - assert isinstance(shape, list) and all( - [(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape] - ), "'shape' must be a list of int or str with length at least 1" - assert isinstance(is_loss, bool), "'is_loss' must be a bool" - - not_found = True - if not ignore_duplicate: - if id(node) == id(self.outputs): - not_found = all([name not in o_desc.name for o_desc in node]) - assert not_found, f"'name' {name} already exists in the outputs description" - assert ( - all([not o_desc.is_loss for o_desc in node]) if is_loss else True - ), "Only one 'is_loss' is supported at outputs description" - else: - not_found = attr_name not in dir(self) - assert not_found, f"'attr_name' {attr_name} already exists in the 'node'" - elif not not_found: - return - - assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type" - if dtype: - new_output_desc = self._OutputDescriptionTyped(name, shape, is_loss, dtype, None) - else: - new_output_desc = self._OutputDescription(name, shape, is_loss) - - if id(node) == id(self.outputs): - self.outputs.append(new_output_desc) - else: - assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'" - setattr(node, attr_name, 
new_output_desc) - - def _wrap(self, v): - """Add 'v' as self's attribute to allow direct access as self.v""" - if isinstance(v, (list)): - return type(v)([self._wrap(v) for v in v]) - elif isinstance( - v, - ( - self._InputDescription, - self._InputDescriptionTyped, - self._OutputDescription, - self._OutputDescriptionTyped, - ), - ): - return v - elif isinstance(v, (tuple)): - return type(v)([self._wrap(v) for v in v]) - elif isinstance(v, (dict, int, float, bool, str)): - return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v - else: - raise ValueError( - f"Unsupported type for model_desc ({v})." - "Only int, float, bool, str, list, tuple and dict are supported" - ) - - -class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc): - r"""Internal class used by ONNX Runtime training backend for input validation - - NOTE: Users MUST NOT use this class in any way! - """ - - def __init__(self, main_class_name, model_desc): - # Used for logging purposes - self._main_class_name = main_class_name - - # Convert dict in object - for k, v in dict(model_desc).items(): - setattr(self, k, self._wrap(v)) - - -def _model_desc_inputs_validation(field, value, error): - r"""Cerberus custom check method for 'model_desc.inputs' - - 'model_desc.inputs' is a list of tuples. - The list has variable length, but each tuple has size 2 - - The first element of the tuple is a string which represents the input name - The second element is a list of shapes. Each shape must be either an int or string. - Empty list represents a scalar output - - Validation is done within each tuple to enforce the schema described above. - - Example: - - .. 
code-block:: python - - model_desc['inputs'] = [('input1', ['batch', 1024]), - ('input2', []) - ('input3', [512])] - """ - - if not isinstance(value, tuple) or len(value) != 2: - error(field, "must be a tuple with size 2") - if not isinstance(value[0], str): - error(field, "the first element of the tuple (aka name) must be a string") - if not isinstance(value[1], list): - error(field, "the second element of the tuple (aka shape) must be a list") - else: - for shape in value[1]: - if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool): - error(field, "each shape must be either a string or integer") - - -@static_vars(loss_counter=0) -def _model_desc_outputs_validation(field, value, error): - r"""Cerberus custom check method for 'model_desc.outputs' - - 'model_desc.outputs' is a list of tuples with variable length. - The first element of the tuple is a string which represents the output name - The second element is a list of shapes. Each shape must be either an int or string. - Empty list represents a scalar output - The third element is optional and is a flag that signals whether the output is a loss value - - Validation is done within each tuple to enforce the schema described above, but also - throughout the list of tuples to ensure a single 'is_loss=True' occurrence. - - Example: - - .. 
code-block:: python - - model_desc['outputs'] = [('output1', ['batch', 1024], is_loss=True), - ('output2', [], is_loss=False) - ('output3', [512])] - """ - - if not isinstance(value, tuple) or len(value) < 2 or len(value) > 3: - error(field, "must be a tuple with size 2 or 3") - if len(value) == 3 and not isinstance(value[2], bool): - error(field, "the third element of the tuple (aka is_loss) must be a boolean") - elif len(value) == 3: - if value[2]: - _model_desc_outputs_validation.loss_counter += 1 - if _model_desc_outputs_validation.loss_counter > 1: - error(field, "only one is_loss can bet set to True") - if not isinstance(value[0], str): - error(field, "the first element of the tuple (aka name) must be a string") - if not isinstance(value[1], list): - error(field, "the second element of the tuple (aka shape) must be a list") - else: - for shape in value[1]: - if not isinstance(shape, str) and not isinstance(shape, int) or isinstance(shape, bool): - error(field, "each shape must be either a string or integer") - - -# Validation schema for model description dictionary -MODEL_DESC_SCHEMA = { - "inputs": { - "type": "list", - "required": True, - "minlength": 1, - "schema": {"check_with": _model_desc_inputs_validation}, - }, - "outputs": { - "type": "list", - "required": True, - "minlength": 1, - "schema": {"check_with": _model_desc_outputs_validation}, - }, -} diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py deleted file mode 100644 index d5a488c436a1d..0000000000000 --- a/orttraining/orttraining/python/training/orttrainer.py +++ /dev/null @@ -1,1537 +0,0 @@ -import copy -import io -import os -import warnings -from functools import partial -from inspect import signature - -import numpy as np -import onnx -import torch - -import onnxruntime as ort -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - -from . 
import _checkpoint_storage, _utils, amp, checkpoint, optim, postprocess -from .model_desc_validation import _ORTTrainerModelDesc -from .orttrainer_options import ORTTrainerOptions - - -class TrainStepInfo: - r"""Private class used to store runtime information from current train step. - - After every train step, :py:meth:`ORTTrainer.train_step` updates the internal instance of - :py:class:`.TrainStepInfo` residing on :py:class:`.ORTTrainer` with relevant information - from the forward pass. - - This class shouldn't be accessed directly by the user, unless they really know what they are doing. - Instead, :py:class:`.ORTTrainer` passes it to relevant class methods automatically, - such as :py:method:`._LRScheduler.get_lr` or :py:class:`.LossScaler.update`. - - Args: - optimizer_config (optim._OptimizerConfig): reference to optimizer config - all_finite (bool, default is True): flag that indicates whether all gradients are still finite after last step - fetches (list of str, default is []): list of output names to fetch from train_step/eval_step. Set it to [] to reset normal behavior. - optimization_step (int): indicates the number of optimizations performed. Used for learning rate scheduling - step (int): indicates current training step. Used for gradient accumulation - - Example: - - .. 
code-block:: python - - info = TrainStepInfo(optimizer_config=optim.SGDConfig(lr=0.01)) - if info.all_finite: - print(f'Yay, all gradients are finite at {step} step!') - - """ - - def __init__(self, optimizer_config, all_finite=True, fetches=[], optimization_step=0, step=0): # noqa: B006 - assert isinstance(optimizer_config, optim._OptimizerConfig), "optimizer_config must be a optim._OptimizerConfig" - assert isinstance(all_finite, bool), "all_finite must be a bool" - assert isinstance(fetches, list) and all( - [isinstance(item, str) for item in fetches] - ), "fetches must be a list of str" - assert isinstance(optimization_step, int) and optimization_step >= 0, "optimization_step must be a positive int" - assert isinstance(step, int) and step >= 0, "step must be a positive int" - - self.optimizer_config = optimizer_config - self.all_finite = all_finite - self.fetches = fetches - self.optimization_step = optimization_step - self.step = step - - -class ORTTrainer: - r"""Pytorch frontend for ONNX Runtime training - - Entry point that exposes the C++ backend of ORT as a Pytorch frontend. - - Args: - model (torch.nn.Module or onnx.ModelProto): either a PyTorch or ONNX model. - When a PyTorch model and :py:attr:`loss_fn` are specified, :py:attr:`model` and :py:obj:`loss_fn` are combined. - When a ONNX model is provided, the loss is identified by the flag :py:obj:`is_loss=True` in one of the :py:attr:`.model_desc.outputs` entries. - model_desc (dict): model input and output description. - This is used to identify inputs and outputs and their shapes, so that ORT can generate back propagation graph, plan memory allocation for - training, and perform optimizations. - :py:attr:`model_desc` must be consistent with the training :py:attr:`model` and have the following (:py:obj:`dict`) schema - :py:obj:`{ 'inputs': [tuple(name, shape)], 'outputs': [tuple(name, shape, is_loss)]}`. - :py:attr:`name` is a string representing the name of input or output of the model. 
- For :py:obj:`model_desc['inputs']` entries, :py:attr:`name` must match input names of the original PyTorch model's :py:meth:`torch.nn.Module.forward` method. - For ONNX models, both name and order of input names must match. - For :py:obj:`model_desc['outputs']` entries, the order must match the original PyTorch's output as returned by :py:meth:`torch.nn.Module.forward` method. - For ONNX models, both name and order of output names must match. - :py:attr:`shape` is a list of string or integers that describes the shape of the input/output. - Each dimension size can be either a string or an int. String means the dimension size is dynamic, while integers mean static dimensions. - An empty list implies a scalar. - Lastly, :py:attr:`is_loss` is a boolean (default is False) that flags if this output is considered a loss. - ORT backend needs to know which output is loss in order to generate back propagation graph. - Loss output must be specified when either :py:attr:`loss_fn` is specified or when loss is embedded in the model. - Note that only one loss output is supported per model. - optimizer_config (optim._OptimizerConfig): optimizer config. - One of :py:class:`.optim.AdamConfig`, :py:class:`.optim.LambConfig` or :py:class:`.optim.SGDConfig`. - loss_fn (callable, default is None): a PyTorch loss function. - It takes two inputs [prediction, label] and outputs a scalar loss tensor. - If provided, :py:attr:`loss_fn` is combined with the PyTorch :py:attr:`model` to form a combined PyTorch model. - Inputs to the combined PyTorch model are concatenation of the :py:attr:`model`'s input and :py:attr:`loss_fn`'s label input. - Outputs of the combined PyTorch model are concatenation of :py:attr:`loss_fn`'s loss output and :py:attr:`model`'s outputs. - options (ORTTrainerOptions, default is None): options for additional features. - Example: - - .. code-block:: python - - model = ... - loss_fn = ... 
- model_desc = { - "inputs": [ - ("input_ids", ["batch", "max_seq_len_in_batch"]), - ("attention_mask", ["batch", "max_seq_len_in_batch"]), - ("token_type_ids", ["batch", "max_seq_len_in_batch"]), - ("masked_lm_labels", ["batch", "max_seq_len_in_batch"]), - ("next_sentence_label", ["batch", 1]) - ], - "outputs": [ - ("loss", [], True), - ], - } - optim_config = optim.LambConfig(param_groups = [ { 'params' : ['model_param0'], 'alpha' : 0.8, 'beta' : 0.7}, - { 'params' : ['model_param1' , 'model_param_2'], 'alpha' : 0.0} - ], - alpha=0.9, beta=0.999) - ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn) - """ - - def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): - warnings.warn( - "ORTTrainer is deprecated and will be removed in ort release 1.14. Please use ORTModule instead.", - FutureWarning, - ) - - assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model" - assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'" - assert isinstance( - optim_config, optim._OptimizerConfig - ), "'optim_config' is required and must be any of 'AdamConfig', 'LambConfig' or 'SGDConfig'" - assert loss_fn is None or ( - callable(loss_fn) and len(signature(loss_fn).parameters) == 2 - ), "'loss_fn' must be either 'None' or a callable with two parameters" - assert options is None or isinstance( - options, ORTTrainerOptions - ), "'options' must be either 'None' or 'ORTTrainerOptions'" - - # Model + Loss validation - # Supported combinarios are - # ---------------------------------------- - # | | Model | Loss | - # ---------------------------------------- - # | 1 | torch.nn.Module | None | - # | 2 | torch.nn.Module | torch.nn.Module | - # | 3 | ONNX | None | - # ---------------------------------------- - self._torch_model = None - self._onnx_model = None - if isinstance(model, torch.nn.Module): - assert loss_fn is None or isinstance( - model, torch.nn.Module - ), "'loss_fn' must be either 
'None' or 'torch.nn.Module'" - self._torch_model = model - self.loss_fn = loss_fn - # TODO: Remove when experimental checkpoint functions are removed. - self._torch_state_dict_keys = list(model.state_dict().keys()) - elif isinstance(model, onnx.ModelProto): - assert loss_fn is None, "'loss_fn' must not be specified when 'model' is an ONNX model" - self._onnx_model = model - self.loss_fn = None - else: - raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'") - - self.model_desc = _ORTTrainerModelDesc(model_desc) - self.optim_config = optim_config - - # ORTTrainerOptions - if not options: - options = ORTTrainerOptions() - self.options = options - if self.options.mixed_precision.enabled and not self.options.mixed_precision.loss_scaler: - # TODO: Move this to model_desc_validation.py - self.options.mixed_precision.loss_scaler = amp.loss_scaler.DynamicLossScaler() - # Post processing ONNX model given as input - if self._onnx_model: - if self.options._internal_use.enable_internal_postprocess: - self._onnx_model = postprocess.run_postprocess(self._onnx_model) - if self.options._internal_use.extra_postprocess: - self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model) - assert isinstance(self._onnx_model, onnx.ModelProto), "'extra_postprocess' must return a ONNX model" - - # When input model is already ONNX (and not exported from Pytorch within ORTTrainer), - # append 'dtype' from ONNX into model description's - for idx_i, i_desc in enumerate(self.model_desc.inputs): - dtype = None - for onnx_input in self._onnx_model.graph.input: - if onnx_input.name == i_desc.name: - dtype = _utils.dtype_onnx_to_torch(onnx_input.type.tensor_type.elem_type) - self.model_desc.add_type_to_input_description(idx_i, dtype) - break - assert dtype is not None, f"ONNX model with unknown input type ({i_desc.name})" - for idx_o, o_desc in enumerate(self.model_desc.outputs): - dtype = None - for onnx_output in self._onnx_model.graph.output: - if 
onnx_output.name == o_desc.name: - dtype = _utils.dtype_onnx_to_torch(onnx_output.type.tensor_type.elem_type) - self.model_desc.add_type_to_output_description(idx_o, dtype) - break - assert dtype is not None, f"ONNX model with unknown output type ({o_desc.name})" - - try: - from torch.utils.cpp_extension import ROCM_HOME - - self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None) - except ImportError: - self.is_rocm_pytorch = False - - # TODO: Remove when experimental checkpoint functions are removed. - self._state_dict = {} - - self._train_step_info = TrainStepInfo(self.optim_config) - self._training_session = None - self._load_state_dict = None - self._init_session( - provider_options=self.options._validated_opts["provider_options"], - session_options=self.options.session_options, - ) - - def eval_step(self, *args, **kwargs): - r"""Evaluation step method - - Args: - *args: Arbitrary arguments that are used as model input (data only) - **kwargs: Arbitrary keyword arguments that are used as model input (data only) - - Returns: - ordered :py:obj:`list` with model outputs as described by :py:attr:`.ORTTrainer.model_desc` - """ - # Get data. CombineTorchModelLossFn takes label as last input and outputs loss first - sample_input = self._prepare_model_input(self.model_desc.inputs, None, None, *args, **kwargs) - - # Export model to ONNX - if self._onnx_model is None: - if self._torch_model is not None: - self._init_onnx_model(sample_input) - else: - raise RuntimeError("Model is uninitialized. 
Only ONNX and PyTorch models are supported") - - # Prepare input/output description - inputs_desc = self.model_desc.inputs - outputs_desc = self.model_desc.outputs - if self._train_step_info.fetches: - outputs_desc = [o_desc for o_desc in outputs_desc if o_desc.name in self._train_step_info.fetches] - if len(outputs_desc) != len(self._train_step_info.fetches): - raise RuntimeError("The specified fetches list contains invalid output names") - - # Normalize input - if not isinstance(sample_input, (list, tuple)): - sample_input = (sample_input,) - - # RunOptions - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - run_options.training_mode = False - - # Run a eval step and return - session_run_results = self._training_session_run_helper( - False, sample_input, inputs_desc, outputs_desc, run_options - ) - - # Output must be returned in the same order as defined in the model description - results = [session_run_results[o_desc.name] for o_desc in outputs_desc] - return results[0] if len(results) == 1 else results - - def save_as_onnx(self, path): - r"""Persists ONNX model into :py:attr:`path` - - The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard. - The graph includes full information, including inference and training metadata. - - Args: - path (str): Full path, including filename, to save the ONNX model in the filesystem - - Raises: - RuntimeWarning: raised when neither `train_step` or `eval_step` was called at least once - ValueError: raised when `path` is not valid path - """ - if not self._training_session: - warnings.warn( - "Training session is not initialized yet. " - "'train_step' or 'eval_step' methods must be executed at least once before calling 'save_as_onnx()'." 
- ) - return - state_tensors = self._training_session.get_state() - self._update_onnx_model_initializers(state_tensors) - - assert isinstance(path, str), "'path' must be a valid path string" - dir_name = os.path.dirname(path) - file_name = os.path.basename(path) - if (dir_name and not os.path.exists(dir_name)) or not file_name: - warnings.warn("'path' is not valid or does not exist") - return - - with open(path, "wb") as f: - f.write(self._onnx_model.SerializeToString()) - - def _check_model_export(self, input): - from numpy.testing import assert_allclose - from onnx import TensorProto, helper, numpy_helper # noqa: F401 - - onnx_model_copy = copy.deepcopy(self._onnx_model) - - # Mute the dropout nodes - dropout_nodes = [n for n in onnx_model_copy.graph.node if n.op_type == "Dropout"] - for node in dropout_nodes: - ratio_node = next(n for n in onnx_model_copy.graph.node if node.input[1] in n.output) - training_mode_node = next(n for n in onnx_model_copy.graph.node if node.input[2] in n.output) - - training_mode_node.attribute.pop() - ratio_node.attribute.pop() - new_training_mode_arr = np.array(False, dtype=bool) - new_ratio_arr = np.array(0.0, dtype=np.float32) - new_training_mode = numpy_helper.from_array(new_training_mode_arr) - new_ratio = numpy_helper.from_array(new_ratio_arr) - training_mode_node.attribute.add().t.CopyFrom(new_training_mode) - ratio_node.attribute.add().t.CopyFrom(new_ratio) - training_mode_node.attribute[0].type = 4 - ratio_node.attribute[0].type = 4 - training_mode_node.attribute[0].name = "value" - ratio_node.attribute[0].name = "value" - - _inference_sess = ort.InferenceSession( - onnx_model_copy.SerializeToString(), providers=ort.get_available_providers() - ) - inf_inputs = {} - for i, input_elem in enumerate(input): - inf_inputs[_inference_sess.get_inputs()[i].name] = input_elem.cpu().numpy() - _inference_outs = _inference_sess.run(None, inf_inputs) - for torch_item, ort_item in zip(self.torch_sample_outputs, _inference_outs): - 
assert_allclose( - torch_item, - ort_item, - rtol=1e-2, - atol=1e-6, - err_msg="Mismatch between outputs of PyTorch model and exported ONNX model. " - "Note that different backends may exhibit small computational differences." - "If this is within acceptable margin, or if there is random generator " - "in the model causing inevitable mismatch, you can proceed training by " - "setting the flag debug.check_model_export to False.", - ) - - def train_step(self, *args, **kwargs): - r"""Train step method - - After forward pass, an ordered list with all outputs described at :py:attr:`ORTTrainer.model_desc` is returned. - Additional information relevant to the train step is maintend by :py:attr:`ORTTrainer._train_step_info`. - See :py:class:`.TrainStepInfo` for details. - - Args: - *args: Arbitrary arguments that are used as model input (data only) - **kwargs: Arbitrary keyword arguments that are used as model input (data only) - - Returns: - ordered :py:obj:`list` with model outputs as described by :py:attr:`ORTTrainer.model_desc` - """ - # Export model to ONNX - if self._onnx_model is None: - sample_input = self._prepare_model_input(self.model_desc.inputs, None, None, *args, **kwargs) - self._init_onnx_model(sample_input) - - # Debug Model Export if indicated - if self.options.debug.check_model_export: - self._check_model_export(sample_input) - - # Prepare inputs+lr and output descriptions - inputs_desc = self._model_desc_inputs_with_lr - outputs_desc = self.model_desc.outputs - - # Train step must be incremented *before* gradient accumulation code - # Gradients are accumulated when - # self._train_step_info.step % self.options.batch.gradient_accumulation_steps != 0, - # and they are updated otherwise - self._train_step_info.step += 1 - - # RunOptions - run_options = None - mixed_precision_without_fetches = False - if self._train_step_info.fetches: - outputs_desc = [o_desc for o_desc in outputs_desc if o_desc.name in self._train_step_info.fetches] - if len(outputs_desc) 
!= len(self._train_step_info.fetches): - raise RuntimeError("The specified fetches list contains invalid output names") - elif self._train_step_info.step % self.options.batch.gradient_accumulation_steps != 0: - run_options = ort.RunOptions() - run_options.only_execute_path_to_fetches = True - outputs_desc = self._model_desc_outputs_with_gradient_accumulation - elif self.options.mixed_precision.enabled: - mixed_precision_without_fetches = True - outputs_desc = self._model_desc_outputs_with_all_finite - - # Update Learning Rate if Necessary - lr = self.optim_config.lr - if self.options.lr_scheduler: - lr = self.options.lr_scheduler._step(self._train_step_info)[0] - - # Loss Scale for mixed precision - loss_scale = None - if self.options.mixed_precision.enabled: - loss_scaler = self.options.mixed_precision.loss_scaler - assert loss_scaler, "Loss scaler is required when mixed precision is enabled" - loss_scale = loss_scaler.loss_scale - inputs_desc = self._model_desc_inputs_with_lr_and_loss_scale - - # Get data. CombineTorchModelLossFn takes label as last input and outputs loss first - input = self._prepare_model_input(inputs_desc, lr, loss_scale, *args, **kwargs) - - # Normalize input - if not isinstance(args, (list, tuple)): - args = (args,) - - # Run a train step and return - session_run_results = self._training_session_run_helper(True, input, inputs_desc, outputs_desc, run_options) - if mixed_precision_without_fetches: - # After session run with all_fp32_gradients_finite, we need to clear the training I/O binding's output - # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce - # because all_fp32_gradients_finite is still in the feed. 
- self._train_io_binding.clear_binding_outputs() - - is_all_finite = session_run_results[self.model_desc.all_finite.name] - self._train_step_info.all_finite = is_all_finite - if loss_scaler: - loss_scaler.update(self._train_step_info) - if is_all_finite: - # Optimization step must be incremented *after* optimization is successful - self._train_step_info.optimization_step += 1 - elif self._train_step_info.step % self.options.batch.gradient_accumulation_steps == 0: - # Optimization step must be incremented *after* optimization is successful - self._train_step_info.optimization_step += 1 - - # Output must be returned in the same order as defined in the model description - # or in the order specified by TrainStepInfo.fetches, if applicable - if self._train_step_info.fetches: - results = [session_run_results[o_desc] for o_desc in self._train_step_info.fetches] - else: - results = [session_run_results[o_desc.name] for o_desc in self.model_desc.outputs] - return results[0] if len(results) == 1 else results - - def _convert_torch_model_loss_fn_to_onnx(self, inputs, device): - # Dynamic axes - dynamic_axes = {} - for input in self.model_desc.inputs: - symbolic_axis = {} - for i, axis in enumerate(input.shape): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[input.name] = symbolic_axis - for output in self.model_desc.outputs: - symbolic_axis = {} - for i, axis in enumerate(output.shape): - if isinstance(axis, str): - symbolic_axis[i] = axis - if len(symbolic_axis): - dynamic_axes[output.name] = symbolic_axis - - if isinstance(inputs, torch.Tensor): - inputs = [inputs] - if isinstance(inputs, dict): - sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs] - elif isinstance(inputs, (list, tuple)): - sample_inputs = [ - input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs) - ] - else: - raise RuntimeError( - "Unexpected input type. 
Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported." - ) - - # PyTorch ONNX exporter does not match argument names - # This is an issue because the ONNX graph depends on all inputs to be specified - - # Validate loss_fn - if self.loss_fn: - sig_loss = signature(self.loss_fn) - if len(sig_loss.parameters) != 2: - raise RuntimeError("loss function should take two arguments - predict and label.") - - # Basic input names from model - input_names = [input.name for input in self.model_desc.inputs] - sig = signature(self._torch_model.forward) - ordered_input_list = list(sig.parameters.keys()) - - # Label from loss_fn goes after model input - if self.loss_fn: - ordered_input_list = [*ordered_input_list, list(sig_loss.parameters.keys())[1]] - - class CombineTorchModelLossFnWrapInput(torch.nn.Module): - def __init__(self, model, loss_fn, input_names): - super().__init__() - self.model = model - self.loss_fn = loss_fn - self.input_names = input_names - - def forward(self, *inputs): - sig = signature(self.model.forward) - - input_dict = {} - for key in sig.parameters: - if key in self.input_names: - input_dict[key] = inputs[self.input_names.index(key)] - - model_out = self.model(**input_dict) - if self.loss_fn is None: - return model_out - - label = inputs[-1] - preds = model_out - return self.loss_fn(preds, label), preds - - model = CombineTorchModelLossFnWrapInput(self._torch_model, self.loss_fn, input_names) - - # Do an inference to grab output types - model.eval() - with torch.no_grad(): - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - try: - # Deepcopy model, in case model is stateful and changes after model run. - model_copy = copy.deepcopy(model) - except Exception: - model_copy = model - warnings.warn( - "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." 
- " Compute will continue, but unexpected results may occur!" - ) - sample_outputs = model_copy(*sample_inputs_copy) - self.torch_sample_outputs = sample_outputs - model.train() - - if isinstance(sample_outputs, torch.Tensor): - sample_outputs = [sample_outputs] - - # Append 'dtype' for model description's inputs/outputs - for idx_i, sample_input in enumerate(sample_inputs): - if idx_i < len(self.model_desc.inputs): - self.model_desc.add_type_to_input_description(idx_i, sample_input.dtype) - for idx_o, sample_output in enumerate(sample_outputs): - if idx_o < len(self.model_desc.outputs): - self.model_desc.add_type_to_output_description(idx_o, sample_output.dtype) - - # Export the model to ONNX - f = io.BytesIO() - - # Deepcopy inputs, since input values may change after model run. - sample_inputs_copy = copy.deepcopy(sample_inputs) - - # Handle contrib OPs support - from onnxruntime.tools import pytorch_export_contrib_ops - - if self.options._internal_use.enable_onnx_contrib_ops: - pytorch_export_contrib_ops.register() - else: - # Unregister in case they were registered in previous calls. - pytorch_export_contrib_ops.unregister() - - # Export torch.nn.Module to ONNX - torch.onnx.export( - model, - tuple(sample_inputs_copy), - f, - input_names=[input.name for input in self.model_desc.inputs], - output_names=[output.name for output in self.model_desc.outputs], - opset_version=self.options._internal_use.onnx_opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False, - training=torch.onnx.TrainingMode.TRAINING, - ) - onnx_model = onnx.load_model_from_string(f.getvalue()) - - # Remove 'model.' 
prefix introduced by CombineTorchModelLossFn class - if isinstance(model, CombineTorchModelLossFnWrapInput): - replace_name_dict = {} - for n in onnx_model.graph.initializer: - if n.name.startswith("model."): - replace_name_dict[n.name] = n.name[len("model.") :] - n.name = replace_name_dict[n.name] - for n in onnx_model.graph.node: - for i, name in enumerate(n.input): - if name in replace_name_dict: - n.input[i] = replace_name_dict[name] - - return onnx_model - - def _create_ort_training_session(self, optimizer_state_dict=None, session_options=None, provider_options=None): - if optimizer_state_dict is None: - optimizer_state_dict = {} - # Validating frozen_weights names - unused_frozen_weights = [ - n - for n in self.options.utils.frozen_weights - if n not in [i.name for i in self._onnx_model.graph.initializer] - ] - if unused_frozen_weights: - raise RuntimeError(f"{unused_frozen_weights} params from 'frozen_weights' not found in the ONNX model.") - - # Get loss name from model description - loss_name = [item.name for item in self.model_desc.outputs if item.is_loss] - assert len(loss_name) == 1, f"Only one loss output is supported ({len(loss_name)} were specified)" - loss_name = loss_name[0] - - # Parse optimizer parameters - optimizer_attributes_map = {} - optimizer_int_attributes_map = {} - trainable_params = set() - for initializer in self._onnx_model.graph.initializer: - if initializer.name in self.options.utils.frozen_weights: - continue # only trainable parameters are passed to the backend - trainable_params.add(initializer.name) - optimizer_attributes_map[initializer.name] = {} - optimizer_int_attributes_map[initializer.name] = {} - not_in_param_groups = True - for param_group in self.optim_config.params: - if initializer.name not in param_group["params"]: - continue # keep looking for a matching param_group - not_in_param_groups = False - for k, v in param_group.items(): - # 'params' is not a hyper parameter, skip it. 
'lr' per weight is not supported - if k == "params" or k == "lr": - continue - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - - # set default values for params not found in groups - if not_in_param_groups: - for k, v in self.optim_config.defaults.items(): - if k == "lr": - continue - if isinstance(v, float): - optimizer_attributes_map[initializer.name][k] = v - elif isinstance(v, int): - optimizer_int_attributes_map[initializer.name][k] = v - else: - raise ValueError("Optimizer attributes must be either float or int.") - - self.options.distributed.horizontal_parallel_size = max(self.options.distributed.horizontal_parallel_size, 1) - self.options.distributed.data_parallel_size = ( - self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size - ) - - # TrainingParameters - ort_parameters = ort.TrainingParameters() - ort_parameters.loss_output_name = loss_name - ort_parameters.use_mixed_precision = self.options.mixed_precision.enabled - ort_parameters.world_rank = self.options.distributed.world_rank - ort_parameters.world_size = self.options.distributed.world_size - ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps - ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation - ort_parameters.enable_adasum = self.options.distributed.enable_adasum - ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage - ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip - ort_parameters.set_gradients_as_graph_outputs = False - ort_parameters.use_memory_efficient_gradient = self.options.utils.memory_efficient_gradient - ort_parameters.training_optimizer_name = self.optim_config.name - ort_parameters.lr_params_feed_name = 
self.model_desc.learning_rate.name - ort_parameters.weights_to_train = trainable_params - ort_parameters.optimizer_attributes_map = optimizer_attributes_map - ort_parameters.optimizer_int_attributes_map = optimizer_int_attributes_map - if bool(optimizer_state_dict): - ort_parameters.set_optimizer_initial_state(optimizer_state_dict) - - ort_parameters.attn_dropout_recompute = self.options.graph_transformer.attn_dropout_recompute - ort_parameters.gelu_recompute = self.options.graph_transformer.gelu_recompute - ort_parameters.transformer_layer_recompute = self.options.graph_transformer.transformer_layer_recompute - ort_parameters.number_recompute_layers = self.options.graph_transformer.number_recompute_layers - - ort_parameters.data_parallel_size = self.options.distributed.data_parallel_size - ort_parameters.horizontal_parallel_size = self.options.distributed.horizontal_parallel_size - ort_parameters.pipeline_parallel_size = self.options.distributed.pipeline_parallel.pipeline_parallel_size - ort_parameters.num_pipeline_micro_batches = ( - self.options.distributed.pipeline_parallel.num_pipeline_micro_batches - ) - ort_parameters.pipeline_cut_info_string = self.options.distributed.pipeline_parallel.pipeline_cut_info_string - # We have special handling for dictionary-typed option. - # sliced_schema._validated_opts is the original dictionary while sliced_schema is a _ORTTrainerOptionsInternal. - ort_parameters.sliced_schema = self.options.distributed.pipeline_parallel.sliced_schema._validated_opts - # We have special handling for dictionary-typed option. - # sliced_axes._validated_opts is the original dictionary while sliced_schema is a _ORTTrainerOptionsInternal. 
- ort_parameters.sliced_axes = self.options.distributed.pipeline_parallel.sliced_axes._validated_opts - ort_parameters.sliced_tensor_names = self.options.distributed.pipeline_parallel.sliced_tensor_names - - ort_parameters.model_after_graph_transforms_path = ( - self.options.debug.graph_save_paths.model_after_graph_transforms_path - ) - ort_parameters.model_with_gradient_graph_path = ( - self.options.debug.graph_save_paths.model_with_gradient_graph_path - ) - ort_parameters.model_with_training_graph_path = ( - self.options.debug.graph_save_paths.model_with_training_graph_path - ) - - # SessionOptions - session_options = ort.SessionOptions() if session_options is None else session_options - session_options.use_deterministic_compute = self.options.debug.deterministic_compute - if ( - self.options.graph_transformer.attn_dropout_recompute - or self.options.graph_transformer.gelu_recompute - or self.options.graph_transformer.transformer_layer_recompute - ): - session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED - if len(self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path) > 0: - session_options.optimized_model_filepath = ( - self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path - ) - - # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. 
- # for example, load_state_dict will be called before returing the function, and it calls _init_session again - del self._training_session - - # Set provider-specific options if needed - def get_providers(provider_options): - providers = ort.get_available_providers() - if provider_options: - for provider_name in provider_options: - if provider_name in providers: - providers[providers.index(provider_name)] = (provider_name, provider_options[provider_name]) - else: - providers.insert(0, (provider_name, provider_options[provider_name])) - # default: using cuda - elif "cuda" in self.options.device.id.lower(): - gpu_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)} - gpu_ep_name = "ROCMExecutionProvider" if self.is_rocm_pytorch else "CUDAExecutionProvider" - if self.options.device.mem_limit > 0: - gpu_ep_options["gpu_mem_limit"] = self.options.device.mem_limit - - if gpu_ep_name not in providers: - raise RuntimeError( - "ORTTrainer options specify a CUDA device but the {} provider is unavailable.".format( - cuda_ep_name # noqa: F821 - ) - ) - - providers[providers.index(gpu_ep_name)] = (gpu_ep_name, gpu_ep_options) - - return providers - - # TrainingSession - self._training_session = ort.TrainingSession( - self._onnx_model.SerializeToString(), ort_parameters, session_options, get_providers(provider_options) - ) - - # I/O bindings - self._train_io_binding = self._training_session.io_binding() - self._eval_io_binding = self._training_session.io_binding() - - def _init_onnx_model(self, inputs): - if self._onnx_model is not None: - return - - if self._torch_model is not None: - # PyTorch model is moved to cpu to save GPU memory - self._torch_model.cpu() - - # PyTorch buffers (created using 'register_buffer') shouldn't be trained - torch_buffers = list(dict(self._torch_model.named_buffers()).keys()) - self.options.utils.frozen_weights.extend(torch_buffers) - - # Export to ONNX - self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(inputs, 
"cpu") - - # Post processing for ONNX models expported from PyTorch - if self.options._internal_use.enable_internal_postprocess: - self._onnx_model = postprocess.run_postprocess(self._onnx_model) - if self.options._internal_use.extra_postprocess: - self._onnx_model = self.options._internal_use.extra_postprocess(self._onnx_model) - - optimizer_state_dict = {} - if self._load_state_dict: - optimizer_state_dict = self._load_state_dict() - - self._init_session( - optimizer_state_dict, - session_options=self.options.session_options, - provider_options=self.options._validated_opts["provider_options"], - ) - - def _init_session(self, optimizer_state_dict={}, session_options=None, provider_options=None): # noqa: B006 - if self._onnx_model is None: - return - - if self.options.utils.run_symbolic_shape_infer: - self._onnx_model = SymbolicShapeInference.infer_shapes( - self._onnx_model, auto_merge=True, guess_output_rank=True - ) - - # Create training session used by train_step - # pass all optimizer states to the backend - self._create_ort_training_session( - optimizer_state_dict, session_options=session_options, provider_options=provider_options - ) - - # Update model description to update dtype when mixed precision is enabled - # C++ backend modifies model's output dtype from float32 to float16 for mixed precision - # Note that for training we must use float32 and for evaluation we must use float16 - for idx, o_desc in enumerate(self.model_desc.outputs): - if ( - self.options.mixed_precision.enabled - and o_desc.dtype == torch.float32 - and not self._training_session.is_output_fp32_node(o_desc.name) - ): - self.model_desc.add_type_to_output_description(idx, o_desc.dtype, torch.float16) - - # Update model description - self._model_desc_inputs_with_lr = [*self.model_desc.inputs, self.model_desc.learning_rate] - - # Update Mixed Precision, if applicable - if self.options.mixed_precision.enabled: - self.model_desc.loss_scale_input = self._training_session.loss_scale_input_name 
- self._model_desc_inputs_with_lr_and_loss_scale = [ - *self._model_desc_inputs_with_lr, - self.model_desc.loss_scale_input, - ] - self.model_desc.all_finite = _utils.get_all_gradients_finite_name_from_session(self._training_session) - self._model_desc_outputs_with_all_finite = [*self.model_desc.outputs, self.model_desc.all_finite] - elif self.options.mixed_precision.loss_scaler: - raise ValueError("Loss Scaler cannot be specified when Mixed Precision is not enabled") - - # Update Loss Scaler Input Name, if applicable - if self.options.mixed_precision.enabled and self.options.mixed_precision.loss_scaler: - self.options.mixed_precision.loss_scaler.input_name = self.model_desc.loss_scale_input.name - elif not self.options.mixed_precision.enabled and self.options.mixed_precision.loss_scaler: - raise ValueError("Loss Scaler cannot be specified when Mixed Precision is not enabled") - - # Update Gradient Accumulation, if applicable - if self.options.batch.gradient_accumulation_steps > 1: - self.model_desc.gradient_accumulation = _utils.get_gradient_accumulation_name_from_session( - self._training_session - ) - self._model_desc_outputs_with_gradient_accumulation = [ - *self.model_desc.outputs, - self.model_desc.gradient_accumulation, - ] - - # TODO: Remove when experimental checkpoint functions are removed - if self._state_dict: - checkpoint.experimental_load_state_dict(self, self._state_dict, self._load_state_dict_strict) - self._state_dict_debug = self._state_dict - self._state_dict = {} - - def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs): - # Normalize input to tuple of samples - if type(inputs) == tuple and len(inputs) == 1 and type(inputs[0]) == list: # noqa: E721 - input = tuple(inputs[0]) - else: - input = inputs - - # Append input from 'kwargs' - for input_desc in inputs_desc: - if input_desc.name in kwargs: - input = (*input, kwargs[input_desc.name]) - - # Append learning rate - extra_inputs = 0 - if lr is not None: - lr = 
torch.tensor([lr]) - input += (lr,) - extra_inputs += 1 - - # Append loss scale - if loss_scale is not None: - assert self.options.mixed_precision.enabled, "Loss scale cannot be used without mixed precision" - loss_scale = torch.tensor([loss_scale]) - input += (loss_scale,) - extra_inputs += 1 - - # Only assert length of input when fetches is not used - assert self._train_step_info.fetches or len(self.model_desc.inputs) + extra_inputs == len(input) - return input - - def _resolve_symbolic_dimensions(self, inputs, inputs_desc, outputs_desc): - outputs = copy.deepcopy(outputs_desc) - resolved_dims = {} - for input, i_desc in zip(inputs, inputs_desc): - for i_idx, i_axis in enumerate(i_desc.shape): - if isinstance(i_axis, str): - if i_axis not in resolved_dims: - resolved_dims[i_axis] = input.size()[i_idx] - else: - assert resolved_dims[i_axis] == input.size()[i_idx], f"Mismatch in dynamic shape {i_axis}" - - for o_desc in outputs: - for idx_o, o_axis in enumerate(o_desc.shape): - if isinstance(o_axis, str): - o_desc.shape[idx_o] = resolved_dims[o_axis] - - unknown_dim = [o_desc.name for dim in o_desc.shape for o_desc in outputs if isinstance(dim, str)] - if unknown_dim: - raise RuntimeError(f"Cannot execute model with unknown output dimensions ({unknown_dim}") - - return outputs - - def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_desc, run_options=None): - # Select IO binding - if is_train: - iobinding = self._train_io_binding - else: - iobinding = self._eval_io_binding - - # Get the list of the actual session inputs because unused inputs can be removed. 
- input_nodes = self._training_session.get_inputs() - input_node_names = [input_node.name for input_node in input_nodes] - - # Bind input tensors - for input, input_desc in zip(inputs, inputs_desc): - if input_desc.name in input_node_names: - device_index = _utils.get_device_index_from_input(input) - iobinding.bind_input( - input_desc.name, - input.device.type, - device_index, - _utils.dtype_torch_to_numpy(input.dtype), - list(input.size()), - input.data_ptr(), - ) - - # Bind output tensors - outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc) - result = {} - for output_desc in outputs_desc_resolved: - target_device = self.options.device.id - if self.options.mixed_precision.enabled and output_desc.name == self.model_desc.all_finite.name: - # Keep all finite flag on CPU to match backend implementation - # This prevents CPU -> GPU -> CPU copies between frontend and backend - target_device = "cpu" - # the self.options.device may be a device that pytorch does not recognize. - # in that case, we temporary prefer to leave the input/output on CPU and let ORT session - # to move the data between device and host. - # so output will be on the same device as input. 
- try: - torch.device(target_device) - except Exception: - # in this case, input/output must on CPU - assert input.device.type == "cpu" - target_device = "cpu" - - torch_tensor = torch.zeros( - output_desc.shape, - device=target_device, - dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype, - ) - iobinding.bind_output( - output_desc.name, - torch_tensor.device.type, - _utils.get_device_index(target_device), - _utils.dtype_torch_to_numpy(torch_tensor.dtype), - list(torch_tensor.size()), - torch_tensor.data_ptr(), - ) - result[output_desc.name] = torch_tensor - - # Run a train/eval step - self._training_session.run_with_iobinding(iobinding, run_options) - return result - - def _update_onnx_model_initializers(self, state_tensors): - r"""Updates ONNX graph initializers with state_tensors's values - - Usually called to save or load an ONNX model. - - The tensors names of state_tensors are compared to all ONNX initializer tensors - and when the name matches, the ONNX graph is updated with the new value. 
- """ - assert isinstance(state_tensors, dict), "state_tensors must be a dict" - - new_weights = [] - replace_indices = [] - for i, w in enumerate(self._onnx_model.graph.initializer): - if w.name in state_tensors: - new_weights.append(onnx.numpy_helper.from_array(state_tensors[w.name], w.name)) - replace_indices.append(i) - replace_indices.sort(reverse=True) - for w_i in replace_indices: - del self._onnx_model.graph.initializer[w_i] - self._onnx_model.graph.initializer.extend(new_weights) - - def _extract_model_states(self, state_dict, pytorch_format): - """Extract model states from the training session and load into the state_dict""" - - model_states = self._training_session.get_model_state(include_mixed_precision_weights=False) - state_dict[_utils.state_dict_model_key()] = {} - - # extract trained model weights from the training session - for precision in model_states: - state_dict[_utils.state_dict_model_key()][precision] = {} - for model_state_key in model_states[precision]: - if pytorch_format: - state_dict[_utils.state_dict_model_key()][precision][model_state_key] = torch.from_numpy( - model_states[precision][model_state_key] - ) - else: - state_dict[_utils.state_dict_model_key()][precision][model_state_key] = model_states[precision][ - model_state_key - ] - - # extract untrained (frozen) model weights - for node in self._onnx_model.graph.initializer: - if ( - node.name not in state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] - and node.name in self.options.utils.frozen_weights - ): - if pytorch_format: - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][ - node.name - ] = torch.from_numpy(onnx.numpy_helper.to_array(node)) - else: - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][ - node.name - ] = onnx.numpy_helper.to_array(node) - - def _extract_trainer_options(self, state_dict): - """Extract relevant trainer configuration and load it into the state_dict""" - - 
mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key() - zero_stage = _utils.state_dict_trainer_options_zero_stage_key() - world_rank = _utils.state_dict_trainer_options_world_rank_key() - world_size = _utils.state_dict_trainer_options_world_size_key() - optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key() - D_size = _utils.state_dict_trainer_options_data_parallel_size_key() # noqa: N806 - H_size = _utils.state_dict_trainer_options_horizontal_parallel_size_key() # noqa: N806 - - state_dict[_utils.state_dict_trainer_options_key()] = {} - state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = self.options.mixed_precision.enabled - state_dict[_utils.state_dict_trainer_options_key()][ - zero_stage - ] = self.options.distributed.deepspeed_zero_optimization.stage - state_dict[_utils.state_dict_trainer_options_key()][world_rank] = self.options.distributed.world_rank - state_dict[_utils.state_dict_trainer_options_key()][world_size] = self.options.distributed.world_size - state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = self.optim_config.name - state_dict[_utils.state_dict_trainer_options_key()][D_size] = self.options.distributed.data_parallel_size - state_dict[_utils.state_dict_trainer_options_key()][H_size] = self.options.distributed.horizontal_parallel_size - - def _extract_train_step_info(self, state_dict): - """Extract train step info settings and save it into the state_dict""" - - optimization_step = _utils.state_dict_train_step_info_optimization_step_key() - step = _utils.state_dict_train_step_info_step_key() - - state_dict[_utils.state_dict_train_step_info_key()] = {} - state_dict[_utils.state_dict_train_step_info_key()][optimization_step] = self._train_step_info.optimization_step - state_dict[_utils.state_dict_train_step_info_key()][step] = self._train_step_info.step - - def state_dict(self, pytorch_format=False): - """Returns a dictionary with model, train step info and optionally, 
optimizer states - - The returned dictionary contains the following information: - - Model and optimizer states - - Required ORTTrainerOptions settings - - Distributed training information, such as but not limited to ZeRO - - Train step info settings - - Structure of the returned dictionary: - - When `pytorch_format = False` - schema: - { - "model": - { - type: dict, - schema: - { - "full_precision": - { - type: dict, - schema: - { - model_weight_name: - { - type: array - } - } - } - } - }, - "optimizer": - { - type: dict, - schema: - { - model_weight_name: - { - type: dict, - schema: - { - "Moment_1": - { - type: array - }, - "Moment_2": - { - type: array - }, - "Update_Count": - { - type: array, - optional: True # present if optimizer is adam, absent otherwise - } - } - }, - "shared_optimizer_state": - { - type: dict, - optional: True, # present optimizer is shared, absent otherwise. - schema: - { - "step": - { - type: array, - } - } - } - } - }, - "trainer_options": - { - type: dict, - schema: - { - "mixed_precision": - { - type: bool - }, - "zero_stage": - { - type: int - }, - "world_rank": - { - type: int - }, - "world_size": - { - type: int - }, - "optimizer_name": - { - type: str - }, - "data_parallel_size": - { - type: int - }, - "horizontal_parallel_size": - { - type: int - } - } - }, - "partition_info": - { - type: dict, - optional: True, # present if states partitioned, else absent - schema: - { - model_weight_name: - { - type: dict, - schema: - { - "original_dim": - { - type: array - }, - "megatron_row_partition": - { - type: int - } - } - } - } - }, - "train_step_info": - { - type: dict, - schema: - { - "optimization_step": - { - type: int - }, - "step": - { - type: int - } - } - } - } - - When `pytorch_format = True` - schema: - { - model_weight_name: - { - type: tensor - } - } - - Args: - pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema - - Returns: - A dictionary with `ORTTrainer` state - """ - if not 
self._training_session: - warnings.warn( - "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict().", - UserWarning, - ) - return self._load_state_dict.args[0] if self._load_state_dict else {} - - state_dict = {} - - # load training session model states into the state_dict - self._extract_model_states(state_dict, pytorch_format) - if pytorch_format: - if self.options.distributed.deepspeed_zero_optimization.stage > 0: - warnings.warn("Incomplete state_dict: ZeRO enabled", UserWarning) - if self.options.distributed.horizontal_parallel_size > 1: - warnings.warn("Incomplete state_dict: Megatron enabled", UserWarning) - # if pytorch_format is true, return a flat dictionary with only model states - # which is compatible with a PyTorch model - return state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] - - # load training session optimizer states into the state_dict - state_dict[_utils.state_dict_optimizer_key()] = self._training_session.get_optimizer_state() - - # extract the relevant training configuration from the trainer and load them into the state_dict - self._extract_trainer_options(state_dict) - - # Extract train step info settings and load it into the state_dict - self._extract_train_step_info(state_dict) - - # add partition information in case of a distributed run - if ( - self.options.distributed.deepspeed_zero_optimization.stage > 0 - or self.options.distributed.horizontal_parallel_size > 1 - ): - state_dict[_utils.state_dict_partition_info_key()] = self._training_session.get_partition_info_map() - - return state_dict - - def _load_model_states(self, state_dict, strict): - """Load the model states onto the onnx model graph""" - - if _utils.state_dict_model_key() not in state_dict: - return - - # collect all initializer names from the current onnx graph - assert self._onnx_model, "ONNX model graph is not exported" - initializer_names = 
{node.name for node in self._onnx_model.graph.initializer} - - # loaded_initializers dict will be loaded with all the model states from the state dictionary - # that are found in the initializer_names dictionary - loaded_initializers = {} - - # copy over model states from the input state dict onto the onnx model - for precision, precision_states in state_dict[_utils.state_dict_model_key()].items(): - for state_key, state_value in precision_states.items(): - if state_key in initializer_names: - loaded_initializers[state_key] = state_value - elif strict: - raise RuntimeError(f"Unexpected key: {state_key} in state_dict[model][{precision}]") - - # update onnx model from loaded initializers - self._update_onnx_model_initializers(loaded_initializers) - - def _load_optimizer_states(self, current_state_dict, state_dict): - """Load the optimizer states onto the training session state dictionary""" - - def _check_optimizer_mismatch(state_dict): - """Assert that the loaded optimizer has the same config as the current training session config""" - - # the state_dict optimizer_name can be a byte string (if coming from checkpoint file) - # or can be a regular string (coming from user) - optimizer_name = state_dict[_utils.state_dict_trainer_options_key()][ - _utils.state_dict_trainer_options_optimizer_name_key() - ] - - # optimizer_name can be either a regular string or a byte string. 
- # if it is a byte string, convert to regular string using decode() - # if it is a regular string, do nothing to it - try: # noqa: SIM105 - optimizer_name = optimizer_name.decode() - except AttributeError: - pass - assert self.optim_config.name == optimizer_name, "Optimizer mismatch: expected {}, got {}".format( - self.optim_config.name, optimizer_name - ) - - if _utils.state_dict_optimizer_key() not in state_dict: - return - - # check optimizer config names are the same for current session and the sessino being loaded - _check_optimizer_mismatch(state_dict) - - # create an entry for the optimizer in the training session state dictionary - if _utils.state_dict_optimizer_key() not in current_state_dict: - current_state_dict[_utils.state_dict_optimizer_key()] = {} - - # copy over optimizer states from the input state dict onto the training session state dict - for model_state_key, optimizer_dict in state_dict[_utils.state_dict_optimizer_key()].items(): - if model_state_key not in current_state_dict[_utils.state_dict_optimizer_key()]: - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key] = {} - for optimizer_state_key, optimizer_state_value in optimizer_dict.items(): - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key][ - optimizer_state_key - ] = optimizer_state_value - - def _load_state_dict_impl(self, state_dict, strict=True): - """Load the state dictionary onto the onnx model and on the training session graph""" - - # clear the callable partial - self._load_state_dict = None - - def _mismatch_keys(keys1, keys2, in_error_str, allow_unexpected=False): - """Find out the missing and the unexpected keys in two dictionaries - - Throws a runtime error if missing or unexpected keys are found - - Keys in keys1 not in keys2 will be marked as missing - - Keys in keys2 not in keys1 will be marked as unexpected - """ - keys1 = set(keys1) - keys2 = set(keys2) - missing_keys = list(keys1 - keys2) - unexpected_keys = list(keys2 - keys1) - 
if len(missing_keys) > 0: - raise RuntimeError(f"Missing keys: {missing_keys} in {in_error_str}") - if len(unexpected_keys) > 0 and not allow_unexpected: - raise RuntimeError(f"Unexpected keys: {unexpected_keys} in {in_error_str}") - - def _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is any mismatch in the model sub state dictionary between the two state_dicts""" - - # check unxexpected and missing precision keys in the model state_dict compared to the training - # session model state_dict - _mismatch_keys( - current_state_dict[_utils.state_dict_model_key()], - state_dict[_utils.state_dict_model_key()], - "state_dict[model]", - allow_unexpected, - ) - - # check for model state key mismatch - for precision_key in current_state_dict[_utils.state_dict_model_key()]: - _mismatch_keys( - current_state_dict[_utils.state_dict_model_key()][precision_key], - state_dict[_utils.state_dict_model_key()][precision_key], - f"state_dict[model][{precision_key}]", - allow_unexpected, - ) - - def _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is any mismatch in the optimizer sub state dictionary between the two state_dicts""" - - # check for model state key mismatch for the optimizer state_dict - _mismatch_keys( - current_state_dict[_utils.state_dict_optimizer_key()], - state_dict[_utils.state_dict_optimizer_key()], - "state_dict[optimizer]", - allow_unexpected, - ) - - # check for optimizer state keys mismatch - for model_state_key in current_state_dict[_utils.state_dict_optimizer_key()]: - _mismatch_keys( - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key], - state_dict[_utils.state_dict_optimizer_key()][model_state_key], - f"state_dict[optimizer][{model_state_key}]", - allow_unexpected, - ) - - def _check_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): - """Check if there is a mismatch in the keys (model and optimizer) in 
the two state_dicts""" - - # check presence of 'model' in the input state_dict - if _utils.state_dict_model_key() in state_dict: - _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected) - else: - warnings.warn("Missing key: model in state_dict", UserWarning) - # check presence of 'optimizer' in the input state_dict - if _utils.state_dict_optimizer_key() in state_dict: - _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected) - else: - warnings.warn("Missing key: optimizer in state_dict", UserWarning) - - # extract state dict from the current training session. this is to persist the states between - # two training sessions. - # for example, if user provided only the model states, the optimizer states from the current - # training session must be persisted - current_state_dict = {} - if self._training_session: - current_state_dict = self.state_dict() - if strict: - # for Zero enabled, the current trainer might not have the complete state, and we must allow - # extra keys to be present in the state dict - allow_unexpected = self.options.distributed.deepspeed_zero_optimization.stage > 0 - _check_key_mismatch(current_state_dict, state_dict, allow_unexpected) - - # load the model states from the input state dictionary into the onnx graph - self._load_model_states(state_dict, strict) - - # load the optimizer states from the input state dictionary into the training session states - # dictionary - self._load_optimizer_states(current_state_dict, state_dict) - - return ( - current_state_dict[_utils.state_dict_optimizer_key()] - if _utils.state_dict_optimizer_key() in current_state_dict - else {} - ) - - def _load_train_step_info(self, state_dict): - """Load the train step info settings from state dict""" - - if _utils.state_dict_train_step_info_key() not in state_dict: - warnings.warn("Missing key: train_step_info in state_dict", UserWarning) - return - - optimization_step = _utils.state_dict_train_step_info_optimization_step_key() 
- step = _utils.state_dict_train_step_info_step_key() - - self._train_step_info.optimization_step = state_dict[_utils.state_dict_train_step_info_key()][optimization_step] - self._train_step_info.step = state_dict[_utils.state_dict_train_step_info_key()][step] - - def load_state_dict(self, state_dict, strict=True): - """Loads state_dict containing model/optimizer states into ORTTrainer - - The state_dict dictionary may contain the following information: - - Model and optimizer states - - Required ORTTrainerOptions settings - - Distributed training information, such as but not limited to ZeRO - - Args: - state_dict: state dictionary containing both model and optimizer states. The structure of this dictionary - should be the same as the one that is returned by ORTTrainer.state_dict for the case when pytorch_format=False - strict: boolean flag to strictly enforce that the input state_dict keys match the keys from ORTTrainer.state_dict - """ - - # if onnx graph has not been initialized, loading of states will be put on hold. - # a copy of the state_dict and other arguments to the function will be stored until the onnx graph has - # been initialized. 
Once the graph is initialized, the desired states will be loaded onto the grpah - if not self._training_session: - self._load_state_dict = partial(self._load_state_dict_impl, state_dict, strict=strict) - return - - # load the train step info settings - self._load_train_step_info(state_dict) - - # load states onto the frontend onnx graph - optimizer_state_dict = self._load_state_dict_impl(state_dict, strict=strict) - - # create a new training session after loading initializer states onto the onnx graph - # pass the populated states to the training session to populate the backend graph - self._init_session( - optimizer_state_dict, - session_options=self.options.session_options, - provider_options=self.options._validated_opts["provider_options"], - ) - - def save_checkpoint(self, path, user_dict={}, include_optimizer_states=True): # noqa: B006 - """Persists ORTTrainer state dictionary on disk along with user_dict. - - Saves the state_dict along with the user_dict to a file specified by path. - - Args: - path: string representation to a file path or a python file-like object. - if file already exists at path, an exception is raised. - user_dict: custom data to be saved along with the state_dict. This data will be returned - to the user when load_checkpoint is called. - include_optimizer_states: boolean flag indicating whether or not to persist the optimizer states. - on load_checkpoint, only model states will be loaded if include_optimizer_states==True - """ - - # extract state_dict to be saved in the checkpoint - state_dict = self.state_dict() - - # if user_dict is provided, serialize to bytes and convert to hex string. 
- # this helps in loading the types as they are given by the user since hdf5 - # converts to numpy types otherwise - if bool(user_dict): - state_dict[_utils.state_dict_user_dict_key()] = _checkpoint_storage.to_serialized_hex(user_dict) - - # if include_optimizer_states is False, only save the model states in the checkpoint file - if not include_optimizer_states: - if _utils.state_dict_optimizer_key() in state_dict: - del state_dict[_utils.state_dict_optimizer_key()] - - _checkpoint_storage.save(state_dict, path) - - def _aggregation_required(self, loaded_trainer_options): - """Checks if aggregation is required for the loading the state_dict into the ORTTrainer""" - - # To load states in the backend, aggregation is required for every ZeRO - # or Megatron checkpoint - return ( - loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 - or loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1 - ) - - def load_checkpoint(self, *paths, strict=True): - """Loads the saved checkpoint state dictionary into the ORTTrainer - - Reads the saved checkpoint files specified by paths from disk and loads the state dictionary - onto the ORTTrainer. - Aggregates the checkpoint files if aggregation is required. 
- - Args: - paths: one or more files represented as strings where the checkpoint is saved - strict: boolean flag to strictly enforce that the saved checkpoint state_dict - keys match the keys from ORTTrainer.state_dict - Returns: - dictionary that the user had saved when calling save_checkpoint - """ - state_dict = {} - - # check if aggregation is required - loaded_trainer_options = _checkpoint_storage.load(paths[0], key=_utils.state_dict_trainer_options_key()) - if self._aggregation_required(loaded_trainer_options): - # if aggregation is required, aggregation logic must be run on the saved checkpoints - state_dict = checkpoint.aggregate_checkpoints(paths, pytorch_format=False) - else: - # if aggregation is not required, there must only be a single file that needs to be loaded - assert len(paths) == 1, f"Expected number of files to load: 1, got {len(paths)}" - state_dict = _checkpoint_storage.load(paths[0]) - - # extract user dict from the saved checkpoint - user_dict = {} - if _utils.state_dict_user_dict_key() in state_dict: - user_dict = _checkpoint_storage.from_serialized_hex(state_dict[_utils.state_dict_user_dict_key()]) - del state_dict[_utils.state_dict_user_dict_key()] - - self.load_state_dict(state_dict, strict=strict) - - return user_dict diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py deleted file mode 100644 index c63ac6f82c87f..0000000000000 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ /dev/null @@ -1,692 +0,0 @@ -import cerberus - -import onnxruntime as ort -from onnxruntime.capi._pybind_state import PropagateCastOpsStrategy - -from .amp import loss_scaler -from .optim import lr_scheduler - - -class ORTTrainerOptions: - r"""Settings used by ONNX Runtime training backend - - The parameters are hierarchically organized to facilitate configuration through semantic groups - that encompasses features, such as distributed training, etc. 
- - Input validation is performed on the input dict during instantiation to ensure - that supported parameters and values are passed in. Invalid input results - in :py:obj:`ValueError` exception with details on it. - - Args: - options (dict): contains all training options - _validate (bool, default is True): for internal use only - - Supported schema for kwargs: - - .. code-block:: python - - schema = { - 'batch' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'gradient_accumulation_steps' : { - 'type' : 'integer', - 'min' : 1, - 'default' : 1 - } - }, - }, - 'device' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'id' : { - 'type' : 'string', - 'default' : 'cuda' - }, - 'mem_limit' : { - 'type' : 'integer', - 'min' : 0, - 'default' : 0 - } - } - }, - 'distributed': { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'world_rank': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'world_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'local_rank': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'data_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'horizontal_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'pipeline_parallel' : { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'pipeline_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'num_pipeline_micro_batches': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'pipeline_cut_info_string': { - 'type': 'string', - 'default': '' - }, - 'sliced_schema': { - 'type': 'dict', - 'default': {}, - 'keysrules': {'type': 'string'}, - 'valuesrules': { - 'type': 'list', - 'schema': {'type': 'integer'} - } - }, - 'sliced_axes': { - 'type': 'dict', - 'default': {}, - 'keysrules': {'type': 'string'}, - 'valuesrules': {'type': 'integer'} - }, - 'sliced_tensor_names': { - 'type': 'list', - 'schema': {'type': 
'string'}, - 'default': [] - } - } - }, - 'allreduce_post_accumulation': { - 'type': 'boolean', - 'default': False - }, - 'deepspeed_zero_optimization': { - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'stage': { - 'type': 'integer', - 'min': 0, - 'max': 1, - 'default': 0 - }, - } - }, - 'enable_adasum': { - 'type': 'boolean', - 'default': False - } - } - }, - 'lr_scheduler' : { - 'type' : 'optim.lr_scheduler', - 'nullable' : True, - 'default' : None - }, - 'mixed_precision' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'enabled' : { - 'type' : 'boolean', - 'default' : False - }, - 'loss_scaler' : { - 'type' : 'amp.loss_scaler', - 'nullable' : True, - 'default' : None - } - } - }, - 'graph_transformer': { - 'type': 'dict', - 'required': False, - 'default': {}, - 'schema': { - 'attn_dropout_recompute': { - 'type': 'boolean', - 'default': False - }, - 'gelu_recompute': { - 'type': 'boolean', - 'default': False - }, - 'transformer_layer_recompute': { - 'type': 'boolean', - 'default': False - }, - 'number_recompute_layers': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'propagate_cast_ops_config': { - 'type': 'dict', - 'required': False, - 'default': {}, - 'schema': { - 'propagate_cast_ops_strategy': { - 'type': 'onnxruntime.training.PropagateCastOpsStrategy', - 'default': PropagateCastOpsStrategy.FLOOD_FILL - }, - 'propagate_cast_ops_level': { - 'type': 'integer', - 'default': 1 - }, - 'propagate_cast_ops_allow': { - 'type': 'list', - 'schema': {'type': 'string'}, - 'default': [] - } - } - } - } - }, - 'utils' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'frozen_weights' : { - 'type' : 'list', - 'default' : [] - }, - 'grad_norm_clip' : { - 'type' : 'boolean', - 'default' : True - }, - 'memory_efficient_gradient' : { - 'type' : 'boolean', - 'default' : False - }, - 'run_symbolic_shape_infer' : { - 'type' : 'boolean', - 'default' : False - } - } - }, - 'debug' : { - 
'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'deterministic_compute' : { - 'type' : 'boolean', - 'default' : False - }, - 'check_model_export' : { - 'type' : 'boolean', - 'default' : False - }, - 'graph_save_paths' : { - 'type' : 'dict', - 'default': {}, - 'required': False, - 'schema': { - 'model_after_graph_transforms_path': { - 'type': 'string', - 'default': '' - }, - 'model_with_gradient_graph_path':{ - 'type': 'string', - 'default': '' - }, - 'model_with_training_graph_path': { - 'type': 'string', - 'default': '' - }, - 'model_with_training_graph_after_optimization_path': { - 'type': 'string', - 'default': '' - }, - } - }, - } - }, - '_internal_use' : { - 'type' : 'dict', - 'required': False, - 'default' : {}, - 'schema' : { - 'enable_internal_postprocess' : { - 'type' : 'boolean', - 'default' : True - }, - 'extra_postprocess' : { - 'type' : 'callable', - 'nullable' : True, - 'default' : None - }, - 'onnx_opset_version': { - 'type': 'integer', - 'min' : 12, - 'max' :14, - 'default': 14 - }, - 'enable_onnx_contrib_ops' : { - 'type' : 'boolean', - 'default' : True - } - } - }, - 'provider_options':{ - 'type': 'dict', - 'default': {}, - 'required': False, - 'schema': {} - }, - 'session_options': { - 'type': 'SessionOptions', - 'nullable': True, - 'default': None - }, - } - - Keyword arguments: - batch (dict): - batch related settings - batch.gradient_accumulation_steps (int, default is 1): - number of steps to accumulate before do collective gradient reduction - device (dict): - compute device related settings - device.id (string, default is 'cuda'): - device to run training - device.mem_limit (int): - maximum memory size (in bytes) used by device.id - distributed (dict): - distributed training options. 
- distributed.world_rank (int, default is 0): - rank ID used for data/horizontal parallelism - distributed.world_size (int, default is 1): - number of ranks participating in parallelism - distributed.data_parallel_size (int, default is 1): - number of ranks participating in data parallelism - distributed.horizontal_parallel_size (int, default is 1): - number of ranks participating in horizontal parallelism - distributed.pipeline_parallel (dict): - Options which are only useful to pipeline parallel. - distributed.pipeline_parallel.pipeline_parallel_size (int, default is 1): - number of ranks participating in pipeline parallelism - distributed.pipeline_parallel.num_pipeline_micro_batches (int, default is 1): - number of micro-batches. We divide input batch into micro-batches and run the graph. - distributed.pipeline_parallel.pipeline_cut_info_string (string, default is ''): - string of cutting ids for pipeline partition. - distributed.allreduce_post_accumulation (bool, default is False): - True enables overlap of AllReduce with computation, while False, - postpone AllReduce until all gradients are ready - distributed.deepspeed_zero_optimization: - DeepSpeed ZeRO options. - distributed.deepspeed_zero_optimization.stage (int, default is 0): - select which stage of DeepSpeed ZeRO to use. Stage 0 means disabled. - distributed.enable_adasum (bool, default is False): - enable `Adasum `_ - algorithm for AllReduce - lr_scheduler (optim._LRScheduler, default is None): - specifies learning rate scheduler - mixed_precision (dict): - mixed precision training options - mixed_precision.enabled (bool, default is False): - enable mixed precision (fp16) - mixed_precision.loss_scaler (amp.LossScaler, default is None): - specifies a loss scaler to be used for fp16. If not specified, - :py:class:`.DynamicLossScaler` is used with default values. - Users can also instantiate :py:class:`.DynamicLossScaler` and - override its parameters. 
Lastly, a completely new implementation - can be specified by extending :py:class:`.LossScaler` class from scratch - graph_transformer (dict): - graph transformer related configurations - graph_transformer.attn_dropout_recompute(bool, default False) - graph_transformer.gelu_recompute(bool, default False) - graph_transformer.transformer_layer_recompute(bool, default False) - graph_transformer.number_recompute_layers(bool, default False) - graph_transformer.propagate_cast_ops_config (dict): - graph_transformer.propagate_cast_ops_config.strategy(PropagateCastOpsStrategy, default FLOOD_FILL) - Specify the choice of the cast propagation optimization strategy, either, NONE, INSERT_AND_REDUCE or FLOOD_FILL. - NONE strategy does not perform any cast propagation transformation on the graph, although other optimizations - locally change cast operations, for example, in order to fuse Transpose and MatMul nodes, the TransposeMatMulFunsion optimization could - interchange Transpose and Cast if the Cast node exists between Transpose and MatMul. - INSERT_AND_REDUCE strategy inserts and reduces cast operations around the nodes with allowed opcodes. - FLOOD_FILL strategy expands float16 regions in the graph using the allowed opcodes, and unlike - INSERT_AND_REDUCE does not touch opcodes outside expanded float16 region. - graph_transformer.propagate_cast_ops_config.level(integer, default 1) - Optimize by moving Cast operations if propagate_cast_ops_level is non-negative. - Use predetermined list of opcodes considered safe to move before/after cast operation - if propagate_cast_ops_level is positive and use propagate_cast_ops_allow otherwise. - graph_transformer.propagate_cast_ops_config.allow(list of str, []) - List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero. 
- attn_dropout_recompute (bool, default is False): - enable recomputing attention dropout to save memory - gelu_recompute (bool, default is False): - enable recomputing Gelu activation output to save memory - transformer_layer_recompute (bool, default is False): - enable recomputing transformer layerwise to save memory - number_recompute_layers (int, default is 0) - number of layers to apply transformer_layer_recompute, by default system will - apply recompute to all the layers, except for the last one - utils (dict): - miscellaneous options - utils.frozen_weights (list of str, []): - list of model parameter names to skip training (weights don't change) - utils.grad_norm_clip (bool, default is True): - enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer' - utils.memory_efficient_gradient (bool, default is False): - enables use of memory aware gradient builder. - utils.run_symbolic_shape_infer (bool, default is False): - runs symbolic shape inference on the model - debug (dict): - debug options - debug.deterministic_compute (bool, default is False) - forces compute to be deterministic accross runs - debug.check_model_export (bool, default is False) - compares PyTorch model outputs with ONNX model outputs in inference before the first - train step to ensure successful model export - debug.graph_save_paths (dict): - paths used for dumping ONNX graphs for debugging purposes - debug.graph_save_paths.model_after_graph_transforms_path (str, default is "") - path to export the ONNX graph after training-related graph transforms have been applied. - No output when it is empty. - debug.graph_save_paths.model_with_gradient_graph_path (str, default is "") - path to export the ONNX graph with the gradient graph added. No output when it is empty. - debug.graph_save_paths.model_with_training_graph_path (str, default is "") - path to export the training ONNX graph with forward, gradient and optimizer nodes. - No output when it is empty. 
- debug.graph_save_paths.model_with_training_graph_after_optimization_path (str, default is "") - outputs the optimized training graph to the path if nonempty. - _internal_use (dict): - internal options, possibly undocumented, that might be removed without notice - _internal_use.enable_internal_postprocess (bool, default is True): - enable internal internal post processing of the ONNX model - _internal_use.extra_postprocess (callable, default is None) - a functor to postprocess the ONNX model and return a new ONNX model. - It does not override :py:attr:`._internal_use.enable_internal_postprocess`, but complement it - _internal_use.onnx_opset_version (int, default is 14): - ONNX opset version used during model exporting. - _internal_use.enable_onnx_contrib_ops (bool, default is True) - enable PyTorch to export nodes as contrib ops in ONNX. - This flag may be removed anytime in the future. - session_options (onnxruntime.SessionOptions): - The SessionOptions instance that TrainingSession will use. - provider_options (dict): - The provider_options for customized execution providers. it is dict map from EP name to - a key-value pairs, like {'EP1' : {'key1' : 'val1'}, ....} - - Example: - .. 
code-block:: python - - opts = ORTTrainerOptions({ - 'batch' : { - 'gradient_accumulation_steps' : 128 - }, - 'device' : { - 'id' : 'cuda:0', - 'mem_limit' : 2*1024*1024*1024, - }, - 'lr_scheduler' : optim.lr_scheduler.LinearWarmupLRScheduler(), - 'mixed_precision' : { - 'enabled': True, - 'loss_scaler': amp.LossScaler(loss_scale=float(1 << 16)) - } - }) - fp16_enabled = opts.mixed_precision.enabled - """ - - def __init__(self, options={}): # noqa: B006 - # Keep a copy of original input for debug - self._original_opts = dict(options) - - # Used for logging purposes - self._main_class_name = self.__class__.__name__ - - # Validates user input - self._validated_opts = dict(self._original_opts) - validator = ORTTrainerOptionsValidator(_ORTTRAINER_OPTIONS_SCHEMA) - self._validated_opts = validator.validated(self._validated_opts) - if self._validated_opts is None: - raise ValueError(f"Invalid options: {validator.errors}") - - # Convert dict in object - for k, v in self._validated_opts.items(): - setattr(self, k, self._wrap(v)) - - def __repr__(self): - return "{%s}" % str( - ", ".join( - f"'{k}': {v!r}" - for (k, v) in self.__dict__.items() - if k not in ["_original_opts", "_validated_opts", "_main_class_name"] - ) - ) - - def _wrap(self, v): - if isinstance(v, (tuple, list, set, frozenset)): - return type(v)([self._wrap(i) for i in v]) - else: - return _ORTTrainerOptionsInternal(self._main_class_name, v) if isinstance(v, dict) else v - - -class _ORTTrainerOptionsInternal(ORTTrainerOptions): - r"""Internal class used by ONNX Runtime training backend for input validation - - NOTE: Users MUST NOT use this class in any way! - """ - - def __init__(self, main_class_name, options): - # Used for logging purposes - self._main_class_name = main_class_name - # We don't call super().__init__(options) here but still called it "_validated_opts" - # instead of "_original_opts" because it has been validated in the top-level - # ORTTrainerOptions's constructor. 
- self._validated_opts = dict(options) - # Convert dict in object - for k, v in dict(options).items(): - setattr(self, k, self._wrap(v)) - - -class ORTTrainerOptionsValidator(cerberus.Validator): - _LR_SCHEDULER = cerberus.TypeDefinition("lr_scheduler", (lr_scheduler._LRScheduler,), ()) - _LOSS_SCALER = cerberus.TypeDefinition("loss_scaler", (loss_scaler.LossScaler,), ()) - - _SESSION_OPTIONS = cerberus.TypeDefinition("session_options", (ort.SessionOptions,), ()) - - _PROPAGATE_CAST_OPS_STRATEGY = cerberus.TypeDefinition( - "propagate_cast_ops_strategy", (PropagateCastOpsStrategy,), () - ) - - types_mapping = cerberus.Validator.types_mapping.copy() - types_mapping["lr_scheduler"] = _LR_SCHEDULER - types_mapping["loss_scaler"] = _LOSS_SCALER - types_mapping["session_options"] = _SESSION_OPTIONS - types_mapping["propagate_cast_ops_strategy"] = _PROPAGATE_CAST_OPS_STRATEGY - - -def _check_is_callable(field, value, error): - result = False - try: - # Python 3 - result = value is None or callable(value) - except Exception: - # Python 3 but < 3.2 - if hasattr(value, "__call__"): # noqa: B004 - result = True - if not result: - error(field, "Must be callable or None") - - -_ORTTRAINER_OPTIONS_SCHEMA = { - "batch": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": {"gradient_accumulation_steps": {"type": "integer", "min": 1, "default": 1}}, - }, - "device": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "id": {"type": "string", "default": "cuda"}, - "mem_limit": {"type": "integer", "min": 0, "default": 0}, - }, - }, - "distributed": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "world_rank": {"type": "integer", "min": 0, "default": 0}, - "world_size": {"type": "integer", "min": 1, "default": 1}, - "local_rank": {"type": "integer", "min": 0, "default": 0}, - "data_parallel_size": {"type": "integer", "min": 1, "default": 1}, - 
"horizontal_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "pipeline_parallel": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "pipeline_parallel_size": {"type": "integer", "min": 1, "default": 1}, - "num_pipeline_micro_batches": {"type": "integer", "min": 1, "default": 1}, - "pipeline_cut_info_string": {"type": "string", "default": ""}, - "sliced_schema": { - "type": "dict", - "default_setter": lambda _: {}, - "keysrules": {"type": "string"}, - "valuesrules": {"type": "list", "schema": {"type": "integer"}}, - }, - "sliced_axes": { - "type": "dict", - "default_setter": lambda _: {}, - "keysrules": {"type": "string"}, - "valuesrules": {"type": "integer"}, - }, - "sliced_tensor_names": {"type": "list", "schema": {"type": "string"}, "default": []}, - }, - }, - "allreduce_post_accumulation": {"type": "boolean", "default": False}, - "deepspeed_zero_optimization": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "stage": {"type": "integer", "min": 0, "max": 1, "default": 0}, - }, - }, - "enable_adasum": {"type": "boolean", "default": False}, - }, - }, - "lr_scheduler": {"type": "lr_scheduler", "nullable": True, "default": None}, - "mixed_precision": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "enabled": {"type": "boolean", "default": False}, - "loss_scaler": {"type": "loss_scaler", "nullable": True, "default": None}, - }, - }, - "graph_transformer": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "attn_dropout_recompute": {"type": "boolean", "default": False}, - "gelu_recompute": {"type": "boolean", "default": False}, - "transformer_layer_recompute": {"type": "boolean", "default": False}, - "number_recompute_layers": {"type": "integer", "min": 0, "default": 0}, - "propagate_cast_ops_config": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - 
"strategy": { - "type": "propagate_cast_ops_strategy", - "nullable": True, - "default": PropagateCastOpsStrategy.FLOOD_FILL, - }, - "level": {"type": "integer", "min": -1, "default": 1}, - "allow": {"type": "list", "schema": {"type": "string"}, "default": []}, - }, - }, - }, - }, - "utils": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "frozen_weights": {"type": "list", "default": []}, - "grad_norm_clip": {"type": "boolean", "default": True}, - "memory_efficient_gradient": {"type": "boolean", "default": False}, - "run_symbolic_shape_infer": {"type": "boolean", "default": False}, - }, - }, - "debug": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "deterministic_compute": {"type": "boolean", "default": False}, - "check_model_export": {"type": "boolean", "default": False}, - "graph_save_paths": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "model_after_graph_transforms_path": {"type": "string", "default": ""}, - "model_with_gradient_graph_path": {"type": "string", "default": ""}, - "model_with_training_graph_path": {"type": "string", "default": ""}, - "model_with_training_graph_after_optimization_path": {"type": "string", "default": ""}, - }, - }, - }, - }, - "_internal_use": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "schema": { - "enable_internal_postprocess": {"type": "boolean", "default": True}, - "extra_postprocess": {"check_with": _check_is_callable, "nullable": True, "default": None}, - "onnx_opset_version": {"type": "integer", "min": 12, "max": 14, "default": 14}, - "enable_onnx_contrib_ops": {"type": "boolean", "default": True}, - }, - }, - "provider_options": { - "type": "dict", - "default_setter": lambda _: {}, - "required": False, - "allow_unknown": True, - "schema": {}, - }, - "session_options": {"type": "session_options", "nullable": True, "default": None}, -} diff --git 
a/orttraining/orttraining/python/training/postprocess.py b/orttraining/orttraining/python/training/postprocess.py deleted file mode 100644 index 6c2adb6af7978..0000000000000 --- a/orttraining/orttraining/python/training/postprocess.py +++ /dev/null @@ -1,478 +0,0 @@ -import os.path # noqa: F401 -import struct -import sys # noqa: F401 - -import numpy as np # noqa: F401 -import onnx -from onnx import * # noqa: F403 -from onnx import helper, numpy_helper # noqa: F401 - - -def run_postprocess(model): - # this post pass is not required for pytorch >= 1.5 - # where add_node_name in torch.onnx.export is default to True - model = add_name(model) - - # this post pass is not required for pytorch > 1.6 - model = fuse_softmaxNLL_to_softmaxCE(model) - - model = fix_expand_shape(model) - model = fix_expand_shape_pt_1_5(model) - return model - - -def find_input_node(model, arg): - result = [] - for node in model.graph.node: - for output in node.output: - if output == arg: - result.append(node) - return result[0] if len(result) == 1 else None - - -def find_output_node(model, arg): - result = [] - for node in model.graph.node: - for input in node.input: - if input == arg: - result.append(node) - return result[0] if len(result) == 1 else result - - -def add_name(model): - i = 0 - for node in model.graph.node: - node.name = "%s_%d" % (node.op_type, i) - i += 1 - return model - - -# Expand Shape PostProcess - - -def fix_expand_shape(model): - expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"] - model_inputs_names = [i.name for i in model.graph.input] - - for expand_node in expand_nodes: - shape = find_input_node(model, expand_node.input[1]) - if shape.op_type == "Shape": - # an expand subgraph - # Input Input2 - # | | - # | Shape - # | | - # |__ __| - # | | - # Expand - # | - # output - # - # Only if Input2 is one of the model inputs, assign Input2's shape to output of expand. 
- shape_input_name = shape.input[0] - if shape_input_name in model_inputs_names: - index = model_inputs_names.index(shape_input_name) - expand_out = model.graph.value_info.add() - expand_out.name = expand_node.output[0] - expand_out.type.CopyFrom(model.graph.input[index].type) - return model - - -def fix_expand_shape_pt_1_5(model): - # expand subgraph - # Constant - # + - # ConstantOfShape - # | + | - # | + | - # (Reshape subgraph) Mul | - # |___ _________| | - # + | | | - # + Equal | - # +++++|++++++++++++++|++ - # |____________ | + - # | | + - # (subgraph) Where - # | | - # |_____ ___________| - # | | - # Expand - # | - # output - # - # where the Reshape subgraph is - # - # Input - # | | - # | |___________________ - # | | - # Shape Constant Shape Constant - # | ______| | ______| - # | | | | - # Gather Gather - # | | - # Unsqueeze Unsqueeze - # | | - # | ..Number of dims.. | - # | _________________| - # |...| - # Concat Constant - # | | - # |______ __________________| - # | | - # Reshape - # | - # output - # - # This pass will copy Input's shape to the output of Expand. 
- expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"] - model_inputs_names = [i.name for i in model.graph.input] - - for expand_node in expand_nodes: - n_where = find_input_node(model, expand_node.input[1]) - if n_where.op_type != "Where": - continue - - n_equal = find_input_node(model, n_where.input[0]) - n_cos = find_input_node(model, n_where.input[1]) - n_reshape = find_input_node(model, n_where.input[2]) - - if n_equal.op_type != "Equal" or n_cos.op_type != "ConstantOfShape" or n_reshape.op_type != "Reshape": - continue - - n_reshape_e = find_input_node(model, n_equal.input[0]) - n_mul = find_input_node(model, n_equal.input[1]) - if n_reshape_e != n_reshape or n_mul.op_type != "Mul": - continue - - n_cos_m = find_input_node(model, n_mul.input[0]) - n_constant = find_input_node(model, n_mul.input[1]) - if n_cos_m != n_cos or n_constant.op_type != "Constant": - continue - - n_concat = find_input_node(model, n_reshape.input[0]) - n_constant_r = find_input_node(model, n_reshape.input[1]) - if n_concat.op_type != "Concat" or n_constant_r.op_type != "Constant": - continue - - n_input_candidates = [] - for concat_in in n_concat.input: - n_unsqueeze = find_input_node(model, concat_in) - if n_unsqueeze.op_type != "Unsqueeze": - break - n_gather = find_input_node(model, n_unsqueeze.input[0]) - if n_gather.op_type != "Gather": - break - n_shape = find_input_node(model, n_gather.input[0]) - n_constant_g = find_input_node(model, n_gather.input[1]) - if n_shape.op_type != "Shape" or n_constant_g.op_type != "Constant": - break - n_input = n_shape.input[0] - if n_input not in model_inputs_names: - break - n_input_candidates.append(n_input) - - if not n_input_candidates or not all(elem == n_input_candidates[0] for elem in n_input_candidates): - continue - - index = model_inputs_names.index(n_input_candidates[0]) - expand_out = model.graph.value_info.add() - expand_out.name = expand_node.output[0] - expand_out.type.CopyFrom(model.graph.input[index].type) - 
return model - - -# LayerNorm PostProcess - - -def find_nodes(graph, op_type): - nodes = [] - for node in graph.node: - if node.op_type == op_type: - nodes.append(node) - return nodes - - -def is_type(node, op_type): - if node is None or isinstance(node, list): - return False - return node.op_type == op_type - - -def add_const(model, name, output, t_value=None, f_value=None): - const_node = model.graph.node.add() - const_node.op_type = "Constant" - const_node.name = name - const_node.output.extend([output]) - attr = const_node.attribute.add() - attr.name = "value" - if t_value is not None: - attr.type = 4 - attr.t.CopyFrom(t_value) - else: - attr.type = 1 - attr.f = f_value - return const_node - - -def layer_norm_transform(model): - # DEPRECATED: This pass is no longer needed as the transform is handled at the backend. - # Converting below subgraph - # - # input - # | - # ReduceMean - # | - # Sub Constant - # _||_____ | - # | | | - # | | | - # | (optional) Cast (optional) Cast - # | | | - # | | ____________________| - # | | | - # | Pow - # | | - # | ReduceMean - # | | - # | Add - # | | - # |__ __Sqrt - # | | - # Div (weight) - # | | - # | _____| - # | | - # Mul (bias) - # | | - # | _____| - # | | - # Add - # | - # output - # - # to the below subgraph - # - # input (weight) (bias) - # | | | - # | _______| | - # | | ________________| - # | | | - # LayerNormalization - # | - # output - graph = model.graph - - nodes_ReduceMean = find_nodes(graph, "ReduceMean") # noqa: N806 - - id = 0 - layer_norm_nodes = [] - remove_nodes = [] - for reduce_mean in nodes_ReduceMean: - # check that reduce_mean output is Sub - sub = find_output_node(model, reduce_mean.output[0]) - if not is_type(sub, "Sub"): - continue - - # check that sub output[0] is Div and output[1] is Pow - pow, div = find_output_node(model, sub.output[0]) - if is_type(pow, "Cast"): - # During an update in PyTorch, Cast nodes are inserted between Sub and Pow. 
- remove_nodes += [pow] - pow = find_output_node(model, pow.output[0]) - if not is_type(pow, "Pow"): - continue - cast_pow = find_input_node(model, pow.input[1]) - if not is_type(cast_pow, "Cast"): - continue - remove_nodes += [cast_pow] - if not is_type(div, "Div") or not is_type(pow, "Pow"): - continue - - # check that pow ouput is ReduceMean - reduce_mean2 = find_output_node(model, pow.output[0]) - if not is_type(reduce_mean2, "ReduceMean"): - continue - - # check that reduce_mean2 output is Add - add = find_output_node(model, reduce_mean2.output[0]) - if not is_type(add, "Add"): - continue - - # check that add output is Sqrt - sqrt = find_output_node(model, add.output[0]) - if not is_type(sqrt, "Sqrt"): - continue - - # check that sqrt output is div - if div != find_output_node(model, sqrt.output[0]): - continue - - # check if div output is Mul - optional_mul = find_output_node(model, div.output[0]) - if not is_type(optional_mul, "Mul"): - optional_mul = None - continue # default bias and weight not supported - - # check if mul output is Add - if optional_mul is not None: - optional_add = find_output_node(model, optional_mul.output[0]) - else: - optional_add = find_output_node(model, div.output[0]) - if not is_type(optional_add, "Add"): - optional_add = None - continue # default bias and weight not supported - - # add nodes to remove_nodes - remove_nodes.extend([reduce_mean, sub, div, pow, reduce_mean2, add, sqrt]) - - # create LayerNorm node - layer_norm_input = [] - layer_norm_output = [] - - layer_norm_input.append(reduce_mean.input[0]) - - if optional_mul is not None: - remove_nodes.append(optional_mul) - weight = optional_mul.input[1] - layer_norm_input.append(weight) - - if optional_add is not None: - remove_nodes.append(optional_add) - bias = optional_add.input[1] - layer_norm_input.append(bias) - - if optional_add is not None: - layer_norm_output.append(optional_add.output[0]) - elif optional_mul is not None: - 
layer_norm_output.append(optional_mul.output[0]) - else: - layer_norm_output.append(div.output[0]) - - layer_norm_output.append("saved_mean_" + str(id)) - layer_norm_output.append("saved_inv_std_var_" + str(id)) - - epsilon_node = find_input_node(model, add.input[1]) - epsilon = epsilon_node.attribute[0].t.raw_data - epsilon = struct.unpack("f", epsilon)[0] - - layer_norm = helper.make_node( - "LayerNormalization", - layer_norm_input, - layer_norm_output, - "LayerNormalization_" + str(id), - None, - axis=reduce_mean.attribute[0].ints[0], - epsilon=epsilon, - ) - layer_norm_nodes.append(layer_norm) - id += 1 - - # remove orphan constant nodes - for constant in graph.node: - if constant.op_type == "Constant" and constant not in remove_nodes: - is_orphan = True - for out_name in constant.output: - out = find_output_node(model, out_name) - if out not in remove_nodes: - is_orphan = False - if is_orphan: - remove_nodes.append(constant) - - all_nodes = [] - for node in graph.node: - if node not in remove_nodes: - all_nodes.append(node) - - for node in layer_norm_nodes: - all_nodes.append(node) # noqa: PERF402 - - graph.ClearField("node") - graph.node.extend(all_nodes) - return model - - -# Fuse SoftmaxCrossEntropy - - -def fuse_softmaxNLL_to_softmaxCE(onnx_model): # noqa: N802 - # Converting below subgraph - # - # (subgraph) - # | - # LogSoftmax (target) (optional weight) - # | | | - # nll_loss/NegativeLogLikelihoodLoss - # | - # output - # - # to the following - # - # (subgraph) (target) (optional weight) - # | | _____| - # | | | - # SparseSoftmaxCrossEntropy - # | - # output - nll_count = 0 - while True: - nll_count = nll_count + 1 - nll_loss_node = None - nll_loss_node_index = 0 - for nll_loss_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "nll_loss" or node.op_type == "NegativeLogLikelihoodLoss": - nll_loss_node = node - break - - if nll_loss_node is None: - break - - softmax_node = None - softmax_node_index = 0 - 
label_input_name = None - weight_input_name = None - for softmax_node_index, node in enumerate(onnx_model.graph.node): # noqa: B007 - if node.op_type == "LogSoftmax": - # has to be connected to nll_loss - if len(nll_loss_node.input) > 2: - weight_input_name = nll_loss_node.input[2] - if node.output[0] == nll_loss_node.input[0]: - softmax_node = node - label_input_name = nll_loss_node.input[1] - break - elif node.output[0] == nll_loss_node.input[1]: - softmax_node = node - label_input_name = nll_loss_node.input[0] - break - else: - if softmax_node is not None: - break - - if softmax_node is None: - break - - # delete nll_loss and LogSoftmax nodes in order - if nll_loss_node_index < softmax_node_index: - del onnx_model.graph.node[softmax_node_index] - del onnx_model.graph.node[nll_loss_node_index] - else: - del onnx_model.graph.node[nll_loss_node_index] - del onnx_model.graph.node[softmax_node_index] - - probability_output_name = softmax_node.output[0] - node = onnx_model.graph.node.add() - inputs = ( - [softmax_node.input[0], label_input_name, weight_input_name] - if weight_input_name - else [softmax_node.input[0], label_input_name] - ) - node.CopyFrom( - onnx.helper.make_node( - "SparseSoftmaxCrossEntropy", - inputs, - [nll_loss_node.output[0], probability_output_name], - "nll_loss_node_" + str(nll_count), - ) - ) - - return onnx_model diff --git a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py b/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py deleted file mode 100644 index f57f55d14eb1b..0000000000000 --- a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py +++ /dev/null @@ -1,144 +0,0 @@ -import sys -import threading -import time - - -class OutputGrabber: - """ - Class used to grab standard output or another stream. 
- """ - - escape_char = "\b" - - def __init__(self, stream=None, threaded=False): - self.origstream = stream - self.threaded = threaded - if self.origstream is None: - self.origstream = sys.stdout - self.origstreamfd = self.origstream.fileno() - self.capturedtext = "" - # Create a pipe so the stream can be captured: - self.pipe_out, self.pipe_in = os.pipe() - - def __enter__(self): - self.start() - return self - - def __exit__(self, type, value, traceback): - self.stop() - - def start(self): - """ - Start capturing the stream data. - """ - self.capturedtext = "" - # Save a copy of the stream: - self.streamfd = os.dup(self.origstreamfd) - # Replace the original stream with our write pipe: - os.dup2(self.pipe_in, self.origstreamfd) - if self.threaded: - # Start thread that will read the stream: - self.workerThread = threading.Thread(target=self.readOutput) - self.workerThread.start() - # Make sure that the thread is running and os.read() has executed: - time.sleep(0.01) - - def stop(self): - """ - Stop capturing the stream data and save the text in `capturedtext`. - """ - # Print the escape character to make the readOutput method stop: - self.origstream.write(self.escape_char) - # Flush the stream to make sure all our data goes in before - # the escape character: - self.origstream.flush() - if self.threaded: - # wait until the thread finishes so we are sure that - # we have until the last character: - self.workerThread.join() - else: - self.readOutput() - # Close the pipe: - os.close(self.pipe_in) - os.close(self.pipe_out) - # Restore the original stream: - os.dup2(self.streamfd, self.origstreamfd) - # Close the duplicate stream: - os.close(self.streamfd) - - def readOutput(self): - """ - Read the stream data (one byte at a time) - and save the text in `capturedtext`. 
- """ - while True: - char = os.read(self.pipe_out, 1).decode(self.origstream.encoding) - if not char or self.escape_char in char: - break - self.capturedtext += char - - -import os # noqa: E402 -import unittest # noqa: E402 - -import numpy as np # noqa: E402, F401 -import torch # noqa: E402 -import torch.nn as nn # noqa: E402 -import torch.nn.functional as F # noqa: E402 - -from onnxruntime.capi import _pybind_state as torch_ort_eager # noqa: E402, F401 -from onnxruntime.training import optim, orttrainer, orttrainer_options # noqa: E402, F401 - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x, target): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return my_loss(out, target) - - -class OrtEPTests(unittest.TestCase): - def test_external_graph_transformer_triggering(self): - input_size = 784 - hidden_size = 500 - num_classes = 10 - batch_size = 128 - model = NeuralNet(input_size, hidden_size, num_classes) - - model_desc = { - "inputs": [ - ("x", [batch_size, input_size]), - ( - "target", - [ - batch_size, - ], - ), - ], - "outputs": [("loss", [], True)], - } - optim_config = optim.SGDConfig() - opts = orttrainer.ORTTrainerOptions({"device": {"id": "cpu"}}) - model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - # because orttrainer is lazy initialized, feed in a random data to trigger the graph transformer - data = torch.rand(batch_size, input_size) - target = torch.randint(0, 10, (batch_size,)) - - with OutputGrabber() as out: - model.train_step(data, target) - assert "******************Trigger Customized Graph Transformer: MyGraphTransformer!" 
in out.capturedtext - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc b/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc deleted file mode 100644 index 00e933dd14914..0000000000000 --- a/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc +++ /dev/null @@ -1,35 +0,0 @@ -#include "core/optimizer/rewrite_rule.h" -#include "orttraining/core/optimizer/graph_transformer_registry.h" -#include "onnx/defs/schema.h" -#include -#include - -namespace onnxruntime { -namespace training { - -class MyRewriteRule : public RewriteRule { - public: - MyRewriteRule() noexcept - : RewriteRule("MyRewriteRule") { - } - std::vector TargetOpTypes() const noexcept override { - return {}; - } - - private: - bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/, const logging::Logger& /*logger*/) const override { - return true; - } - - Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/, const logging::Logger& /*logger*/) const override { - std::cout << "******************Trigger Customized Graph Transformer: MyGraphTransformer!" 
<< std::endl; - return Status::OK(); - } -}; - -void RegisterTrainingExternalTransformers() { - ONNX_REGISTER_EXTERNAL_REWRITE_RULE(MyRewriteRule, Level1, true); -} - -} // namespace training -} // namespace onnxruntime diff --git a/orttraining/orttraining/test/python/_test_commons.py b/orttraining/orttraining/test/python/_test_commons.py index 1413d59096832..fb7e62551de63 100644 --- a/orttraining/orttraining/test/python/_test_commons.py +++ b/orttraining/orttraining/test/python/_test_commons.py @@ -1,26 +1,7 @@ -import copy -import math import os import subprocess import sys -import numpy as np -import onnx -import torch -from numpy.testing import assert_allclose - -import onnxruntime -from onnxruntime.training import _utils, optim - - -def _single_run(execution_file, scenario, checkopint_dir=None): - cmd = [sys.executable, execution_file] - if scenario: - cmd += ["--scenario", scenario] - if checkopint_dir: - cmd += ["--checkpoint_dir", checkopint_dir] - assert subprocess.call(cmd) == 0 - def is_windows(): return sys.platform.startswith("win") @@ -46,197 +27,3 @@ def run_subprocess(args, cwd=None, capture=False, dll_path=None, shell=False, en if log: log.debug("Subprocess completed. 
Return code=" + str(completed_process.returncode)) return completed_process - - -def legacy_constant_lr_scheduler(global_step, initial_lr, total_steps, warmup): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - new_lr = initial_lr - return new_lr - - -def legacy_cosine_lr_scheduler(global_step, initial_lr, total_steps, warmup, cycles): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - progress = float(global_step - num_warmup_steps) / float(max(1, total_steps - num_warmup_steps)) - new_lr = initial_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * progress))) - return new_lr - - -def legacy_linear_lr_scheduler(global_step, initial_lr, total_steps, warmup): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - else: - new_lr = initial_lr * max(0.0, float(total_steps - global_step) / float(max(1, total_steps - num_warmup_steps))) - return new_lr - - -def legacy_poly_lr_scheduler(global_step, initial_lr, total_steps, warmup, power, lr_end): - num_warmup_steps = warmup * total_steps - if global_step < num_warmup_steps: - new_lr = initial_lr * float(global_step) / float(max(1, num_warmup_steps)) - elif global_step > total_steps: - new_lr = lr_end - else: - lr_range = initial_lr - lr_end - decay_steps = total_steps - num_warmup_steps - pct_remaining = 1 - (global_step - num_warmup_steps) / decay_steps - decay = lr_range * pct_remaining**power + lr_end - new_lr = decay - return new_lr - - -def generate_dummy_optim_state(model, optimizer): - np.random.seed(0) - if not (isinstance(optimizer, (optim.AdamConfig, optim.LambConfig))): - return dict() - - moment_keys = ["Moment_1", "Moment_2"] - uc_key = "Update_Count" - step_key = 
"Step" - shared_state_key = "shared_optimizer_state" - - optim_state = dict() - weight_shape_map = dict() - if isinstance(model, torch.nn.Module): - weight_shape_map = {name: param.size() for name, param in model.named_parameters()} - elif isinstance(model, onnx.ModelProto): - weight_shape_map = {n.name: n.dims for n in model.graph.initializer} - else: - raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'") - - for weight_name, weight_shape in weight_shape_map.items(): - per_weight_state = dict() - for moment in moment_keys: - per_weight_state[moment] = np.random.uniform(-2, 2, weight_shape).astype(np.float32) - if isinstance(optimizer, optim.AdamConfig): - per_weight_state[uc_key] = np.full([1], 5, dtype=np.int64) - optim_state[weight_name] = copy.deepcopy(per_weight_state) - if isinstance(optimizer, optim.LambConfig): - step_val = np.full([1], 5, dtype=np.int64) - optim_state[shared_state_key] = {step_key: step_val} - return {"optimizer": optim_state, "trainer_options": {"optimizer_name": optimizer.name}} - - -def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False, data_dir=None): - # Loads external Pytorch TransformerModel into utils - root = "samples" - if not os.path.exists(root): - root = os.path.normpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "..", "samples") - ) - if not os.path.exists(root): - raise FileNotFoundError("Unable to find folder 'samples', tried %r." 
% root) - pytorch_transformer_path = os.path.join(root, "python", "training", "orttrainer", "pytorch_transformer") - pt_model_path = os.path.join(pytorch_transformer_path, "pt_model.py") - pt_model = _utils.import_module_from_file(pt_model_path) - ort_utils_path = os.path.join(pytorch_transformer_path, "ort_utils.py") - ort_utils = _utils.import_module_from_file(ort_utils_path) - utils_path = os.path.join(pytorch_transformer_path, "utils.py") - utils = _utils.import_module_from_file(utils_path) - - # Modeling - model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - my_loss = ort_utils.my_loss - if legacy_api: - if dynamic_axes: - model_desc = ort_utils.legacy_transformer_model_description_dynamic_axes() - else: - model_desc = ort_utils.legacy_transformer_model_description() - else: - if dynamic_axes: - model_desc = ort_utils.transformer_model_description_dynamic_axes() - else: - model_desc = ort_utils.transformer_model_description() - - # Preparing data - train_data, val_data, test_data = utils.prepare_data(device, 20, 20, data_dir) - return model, model_desc, my_loss, utils.get_batch, train_data, val_data, test_data - - -def generate_random_input_from_bart_model_desc(desc, seed=1, device="cuda:0"): - """Generates a sample input for the BART model using the model desc""" - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - dtype = torch.int64 - vocab_size = 30528 - sample_input = [] - for _index, input in enumerate(desc["inputs"]): - size = [] - for s in input[1]: - if isinstance(s, (int)): - size.append(s) - else: - size.append(1) - sample_input.append(torch.randint(0, vocab_size, tuple(size), dtype=dtype).to(device)) - return sample_input - - -def _load_bart_model(): - bart_onnx_model_path = os.path.join("testdata", "bart_tiny.onnx") - model = onnx.load(bart_onnx_model_path) - batch = 2 - seq_len = 1024 - model_desc = { - "inputs": [ - ( - "src_tokens", - [batch, seq_len], - ), - ( - "prev_output_tokens", - [batch, seq_len], - ), - ( - 
"target", - [batch * seq_len], - ), - ], - "outputs": [("loss", [], True)], - } - - return model, model_desc - - -def assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint, reshape_states=False): - """Assert that the two ORTTrainer (hierarchical) state dictionaries are very close for all states""" - - assert ("model" in state_dict_pre_checkpoint) == ("model" in state_dict_post_checkpoint) - assert ("optimizer" in state_dict_pre_checkpoint) == ("optimizer" in state_dict_post_checkpoint) - - if "model" in state_dict_pre_checkpoint: - for model_state_key in state_dict_pre_checkpoint["model"]["full_precision"]: - if reshape_states: - assert_allclose( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key], - state_dict_post_checkpoint["model"]["full_precision"][model_state_key].reshape( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key].shape - ), - ) - else: - assert_allclose( - state_dict_pre_checkpoint["model"]["full_precision"][model_state_key], - state_dict_post_checkpoint["model"]["full_precision"][model_state_key], - ) - - if "optimizer" in state_dict_pre_checkpoint: - for model_state_key in state_dict_pre_checkpoint["optimizer"]: - for optimizer_state_key in state_dict_pre_checkpoint["optimizer"][model_state_key]: - if reshape_states: - assert_allclose( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key], - state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key].reshape( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key].shape - ), - ) - else: - assert_allclose( - state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key], - state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key], - ) diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index a9a4c7b1cc2ef..8f2a18b5ec00b 100644 --- 
a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -1,30 +1,11 @@ import copy import os -import numpy as np import torch from numpy.testing import assert_allclose -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import orttrainer - -try: - from onnxruntime.training.ortmodule import ORTModule - from onnxruntime.training.ortmodule._fallback import ORTModuleInitException - from onnxruntime.training.ortmodule._graph_execution_manager_factory import ( # noqa: F401 - GraphExecutionManagerFactory, - ) -except ImportError: - # Some pipelines do not contain ORTModule - pass -except Exception as e: - from onnxruntime.training.ortmodule._fallback import ORTModuleInitException - - if isinstance(e, ORTModuleInitException): - # ORTModule is present but not ready to run - # That is OK because this file is also used by ORTTrainer tests - pass - raise +from onnxruntime.training.ortmodule import ORTModule +from onnxruntime.training.ortmodule._graph_execution_manager_factory import GraphExecutionManagerFactory # noqa: F401 def is_all_or_nothing_fallback_enabled(model, policy=None): @@ -66,103 +47,6 @@ def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0): assert_allclose(output_a, output_b, rtol=rtol, atol=atol, err_msg="Model output value mismatch") -def assert_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0): - r"""Asserts whether weight difference between models a and b differences are within specified tolerance - - Compares the weights of two different ONNX models (model_a and model_b) - and raises AssertError when they diverge by more than atol or rtol - - Args: - model_a, model_b (ORTTrainer): Two instances of ORTTrainer with the same model structure - verbose (bool, default is False): if True, prints absolute difference for each weight - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 
1e-4): Max absolute difference - """ - assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, orttrainer.ORTTrainer) - state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b._training_session.get_state() - assert len(state_dict_a.items()) == len(state_dict_b.items()) - _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol) - - -def assert_legacy_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0): - r"""Asserts whether weight difference between models a and b differences are within specified tolerance - - Compares the weights of a legacy model model_a and experimental model_b model - and raises AssertError when they diverge by more than atol or rtol. - - Args: - model_a (ORTTrainer): Instance of legacy ORTTrainer - model_b (ORTTrainer): Instance of experimental ORTTrainer - verbose (bool, default is False): if True, prints absolute difference for each weight. - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - assert isinstance(model_a, orttrainer.ORTTrainer) and isinstance(model_b, Legacy_ORTTrainer) - state_dict_a, state_dict_b = model_a._training_session.get_state(), model_b.session.get_state() - assert len(state_dict_a.items()) == len(state_dict_b.items()) - _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol) - - -def _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol): - r"""Asserts whether dicts a and b value differences are within specified tolerance - - Compares the weights of two model's state_dict dicts and raises AssertError - when they diverge by more than atol or rtol - - Args: - model_a (ORTTrainer): Instance of legacy ORTTrainer - model_b (ORTTrainer): Instance of experimental ORTTrainer - verbose (bool, default is False): if True, prints absolute difference for each weight. 
- rtol (float, default is 1e-7): Max relative difference - atol (float, default is 1e-4): Max absolute difference - """ - - for (a_name, a_val), (_b_name, b_val) in zip(state_dict_a.items(), state_dict_b.items()): - np_a_vals = np.array(a_val).flatten() - np_b_vals = np.array(b_val).flatten() - assert np_a_vals.shape == np_b_vals.shape - if verbose: - print(f"Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}") - assert_allclose(a_val, b_val, rtol=rtol, atol=atol, err_msg=f"Weight mismatch for {a_name}") - - -def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0): - r"""Asserts whether optimizer state differences are within specified tolerance - - Compares the expected and actual optimizer states of dicts and raises AssertError - when they diverge by more than atol or rtol. - The optimizer dict is of the form: - model_weight_name: - { - "Moment_1": moment1_tensor, - "Moment_2": moment2_tensor, - "Update_Count": update_tensor # if optimizer is adam, absent otherwise - }, - ... - "shared_optimizer_state": # if optimizer is shared, absent otherwise. - So far, only lamb optimizer uses this. 
- { - "step": step_tensor # int array of size 1 - } - - Args: - expected_state (dict(dict())): Expected optimizer state - actual_state (dict(dict())): Actual optimizer state - rtol (float, default is 1e-7): Max relative difference - atol (float, default is 0): Max absolute difference - """ - assert expected_state.keys() == actual_state.keys() - for param_name, a_state in actual_state.items(): - for k, v in a_state.items(): - assert_allclose( - v, - expected_state[param_name][k], - rtol=rtol, - atol=atol, - err_msg=f"Optimizer state mismatch for param {param_name}, key {k}", - ) - - def is_dynamic_axes(model): # Check inputs for inp in model._torch_module._execution_manager(model._is_training())._onnx_models.optimized_model.graph.input: diff --git a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py b/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py deleted file mode 100644 index d5298cf8e860e..0000000000000 --- a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py +++ /dev/null @@ -1,325 +0,0 @@ -import os -import unittest - -import torch -import torch.nn as nn -from orttraining_test_bert_postprocess import postprocess_model -from orttraining_test_data_loader import create_ort_test_dataloader -from orttraining_test_transformers import BertForPreTraining, BertModelTest -from orttraining_test_utils import map_optimizer_attributes - -import onnxruntime -from onnxruntime.capi.ort_trainer import ( # noqa: F401 - IODescription, - LossScaler, - ModelDescription, - ORTTrainer, - generate_sample, -) - -torch.manual_seed(1) -onnxruntime.set_seed(1) - - -class Test_PostPasses(unittest.TestCase): # noqa: N801 - def get_onnx_model( - self, model, model_desc, inputs, device, _enable_internal_postprocess=True, _extra_postprocess=None - ): - lr_desc = IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - model = ORTTrainer( - model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes, - lr_desc, - 
device, - world_rank=0, - world_size=1, - _opset_version=14, - _enable_internal_postprocess=_enable_internal_postprocess, - _extra_postprocess=_extra_postprocess, - ) - - model.train_step(*inputs) - return model.onnx_model_ - - def count_all_nodes(self, model): - return len(model.graph.node) - - def count_nodes(self, model, node_type): - count = 0 - for node in model.graph.node: - if node.op_type == node_type: - count += 1 - return count - - def find_nodes(self, model, node_type): - nodes = [] - for node in model.graph.node: - if node.op_type == node_type: - nodes.append(node) - return nodes - - def get_name(self, name): - if os.path.exists(name): - return name - rel = os.path.join("testdata", name) - if os.path.exists(rel): - return rel - this = os.path.dirname(__file__) - data = os.path.join(this, "..", "..", "..", "..", "onnxruntime", "test", "testdata") - res = os.path.join(data, name) - if os.path.exists(res): - return res - raise FileNotFoundError(f"Unable to find '{name}' or '{rel}' or '{res}'") - - def test_layer_norm(self): - class LayerNormNet(nn.Module): - def __init__(self, target): - super().__init__() - self.ln_1 = nn.LayerNorm(10) - self.loss = nn.CrossEntropyLoss() - self.target = target - - def forward(self, x): - output1 = self.ln_1(x) - loss = self.loss(output1, self.target) - return loss, output1 - - device = torch.device("cpu") - target = torch.ones(20, 10, 10, dtype=torch.int64).to(device) - model = LayerNormNet(target) - input = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(device) - - input_desc = IODescription("input", [], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [20, 5, 10, 10], "float32") - model_desc = ModelDescription([input_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [input, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, input_args, device) - - count_layer_norm = 
self.count_nodes(onnx_model, "LayerNormalization") - count_nodes = self.count_all_nodes(onnx_model) - - assert count_layer_norm == 0 - assert count_nodes == 3 - - def test_expand(self): - class ExpandNet(nn.Module): - def __init__(self, target): - super().__init__() - self.loss = nn.CrossEntropyLoss() - self.target = target - self.linear = torch.nn.Linear(2, 2) - - def forward(self, x, x1): - output = x.expand_as(x1) - output = self.linear(output) - output = output + output - loss = self.loss(output, self.target) - return loss, output - - device = torch.device("cpu") - target = torch.ones(5, 5, 2, dtype=torch.int64).to(device) - model = ExpandNet(target).to(device) - - x = torch.randn(5, 3, 1, 2, dtype=torch.float32).to(device) - x1 = torch.randn(5, 3, 5, 2, dtype=torch.float32).to(device) - - input0_desc = IODescription("x", [5, 3, 1, 2], "float32") - input1_desc = IODescription("x1", [5, 3, 5, 2], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [5, 3, 5, 2], "float32") - model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [x, x1, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, input_args, device) - - # check that expand output has shape - expand_nodes = self.find_nodes(onnx_model, "Expand") - assert len(expand_nodes) == 1 - - model_info = onnx_model.graph.value_info - assert model_info[0].name == expand_nodes[0].output[0] - assert model_info[0].type == onnx_model.graph.input[1].type - - def test_bert(self): - device = torch.device("cpu") - - model_tester = BertModelTest.BertModelTester(self) - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = model_tester.prepare_config_and_inputs() - - model = BertForPreTraining(config=config) - model.eval() - - loss, prediction_scores, seq_relationship_score = model( - 
input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - masked_lm_labels=token_labels, - next_sentence_label=sequence_labels, - ) - - model_desc = ModelDescription( - [ - model_tester.input_ids_desc, - model_tester.attention_mask_desc, - model_tester.token_type_ids_desc, - model_tester.masked_lm_labels_desc, - model_tester.next_sentence_label_desc, - ], - [model_tester.loss_desc, model_tester.prediction_scores_desc, model_tester.seq_relationship_scores_desc], - ) - - from collections import namedtuple - - MyArgs = namedtuple( - "MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len" - ) - args = MyArgs( - local_rank=0, - world_size=1, - max_steps=100, - learning_rate=0.00001, - warmup_proportion=0.01, - batch_size=13, - seq_len=7, - ) - - dataset_len = 100 - dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device) - learning_rate = torch.tensor(1.0e0, dtype=torch.float32).to(device) - for b in dataloader: - batch = b - break - learning_rate = torch.tensor([1.00e00]).to(device) - inputs = [*batch, learning_rate] - - onnx_model = self.get_onnx_model(model, model_desc, inputs, device, _extra_postprocess=postprocess_model) - - self._bert_helper(onnx_model) - - def _bert_helper(self, onnx_model): - # count layer_norm - count_layer_norm = self.count_nodes(onnx_model, "LayerNormalization") - assert count_layer_norm == 0 - - # get expand node and check output shape - expand_nodes = self.find_nodes(onnx_model, "Expand") - assert len(expand_nodes) == 1 - - model_info = onnx_model.graph.value_info - assert model_info[0].name == expand_nodes[0].output[0] - assert model_info[0].type == onnx_model.graph.input[0].type - - def test_extra_postpass(self): - def postpass_replace_first_add_with_sub(model): - # this post pass replaces the first Add node with Sub in the model. 
- # Previous graph - # (subgraph 1) (subgraph 2) - # | | - # | | - # |________ ________| - # | | - # Add - # | - # (subgraph 3) - # - # Post graph - # (subgraph 1) (subgraph 2) - # | | - # | | - # |________ ________| - # | | - # Sub - # | - # (subgraph 3) - add_nodes = [n for n in model.graph.node if n.op_type == "Add"] - add_nodes[0].op_type = "Sub" - - class MultiAdd(nn.Module): - def __init__(self, target): - super().__init__() - self.loss = nn.CrossEntropyLoss() - self.target = target - self.linear = torch.nn.Linear(2, 2, bias=False) - - def forward(self, x, x1): - output = x + x1 - output = output + x - output = output + x1 - output = self.linear(output) - loss = self.loss(output, self.target) - return loss, output - - device = torch.device("cpu") - target = torch.ones(5, 2, dtype=torch.int64).to(device) - model = MultiAdd(target).to(device) - - x = torch.randn(5, 5, 2, dtype=torch.float32).to(device) - x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device) - - input0_desc = IODescription("x", [5, 5, 2], "float32") - input1_desc = IODescription("x1", [5, 5, 2], "float32") - output0_desc = IODescription("output0", [], "float32") - output1_desc = IODescription("output1", [5, 5, 2], "float32") - model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) - - learning_rate = torch.tensor([1.0000000e00]).to(device) - input_args = [x, x1, learning_rate] - - onnx_model = self.get_onnx_model( - model, model_desc, input_args, device, _extra_postprocess=postpass_replace_first_add_with_sub - ) - - # check that extra postpass is called, and called only once. - add_nodes = self.find_nodes(onnx_model, "Add") - sub_nodes = self.find_nodes(onnx_model, "Sub") - assert len(add_nodes) == 2 - assert len(sub_nodes) == 1 - - unprocessed_onnx_model = self.get_onnx_model( - model, model_desc, input_args, device, _extra_postprocess=None, _enable_internal_postprocess=False - ) - # check that the model is unchanged. 
- add_nodes = self.find_nodes(unprocessed_onnx_model, "Add") - sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub") - assert len(add_nodes) == 3 - assert len(sub_nodes) == 0 - - processed_onnx_model = self.get_onnx_model( - unprocessed_onnx_model, - model_desc, - input_args, - device, - _extra_postprocess=postpass_replace_first_add_with_sub, - ) - # check that extra postpass is called, and called only once. - add_nodes = self.find_nodes(processed_onnx_model, "Add") - sub_nodes = self.find_nodes(processed_onnx_model, "Sub") - assert len(add_nodes) == 2 - assert len(sub_nodes) == 1 - - -if __name__ == "__main__": - unittest.main(module=__name__, buffer=True) diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index 0e7e9d23ee627..5341cd053ac18 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -43,7 +43,7 @@ def run_ortmodule_ops_tests(cwd, log, transformers_cache): env = get_env_with_transformers_cache(transformers_cache) - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnx_ops_ortmodule.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_onnx_ops.py"] run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode() @@ -146,7 +146,7 @@ def run_data_sampler_tests(cwd, log): def run_hooks_tests(cwd, log): log.debug("Running: Data hooks tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_hooks.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_hooks.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() diff --git a/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py b/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py deleted file mode 100644 index eea733684f140..0000000000000 --- 
a/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py +++ /dev/null @@ -1,801 +0,0 @@ -# ================== -import dataclasses -import datetime -import glob -import json -import logging -import os -import random -import shutil -import unittest -from concurrent.futures import ProcessPoolExecutor -from dataclasses import dataclass, field -from typing import Any, Dict, Optional - -import h5py -import numpy as np -import torch -import torch.distributed as dist -from torch.utils.data import DataLoader, Dataset, RandomSampler -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm -from transformers import BertConfig, BertForPreTraining, HfArgumentParser - -import onnxruntime as ort - -# need to override torch.onnx.symbolic_opset12.nll_loss to handle ignore_index == -100 cases. -# the fix for ignore_index == -100 cases is already in pytorch master. -# however to use current torch master is causing computation changes in many tests. -# eventually we will use pytorch with fixed nll_loss once computation -# issues are understood and solved. -import onnxruntime.capi.pt_patch -from onnxruntime.training import amp, optim, orttrainer -from onnxruntime.training.checkpoint import aggregate_checkpoints -from onnxruntime.training.optim import LinearWarmupLRScheduler, PolyWarmupLRScheduler # noqa: F401 - -# we cannot make full convergence run in nightly pipeling because of its timeout limit, -# max_steps is still needed to calculate learning rate. force_to_stop_max_steps is used to -# terminate the training before the pipeline run hit its timeout. 
-force_to_stop_max_steps = 2500 - -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO -) -logger = logging.getLogger(__name__) - - -def get_rank(): - if not dist.is_available(): - return 0 - if not dist.is_initialized(): - return 0 - return dist.get_rank() - - -def is_main_process(args): - if hasattr(args, "world_rank"): - return args.world_rank in [-1, 0] - else: - return get_rank() == 0 - - -def bert_model_description(config): - vocab_size = config.vocab_size - new_model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - ), - ( - "next_sentence_label", - [ - "batch", - ], - ), - ], - "outputs": [ - ("loss", [], True), - ( - "prediction_scores", - ["batch", "max_seq_len_in_batch", vocab_size], - ), - ( - "seq_relationship_scores", - ["batch", 2], - ), - ], - } - return new_model_desc - - -def create_pretraining_dataset(input_file, max_pred_length, args): - train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length) - train_sampler = RandomSampler(train_data) - train_dataloader = DataLoader( - train_data, sampler=train_sampler, batch_size=args.train_batch_size * args.n_gpu, num_workers=0, pin_memory=True - ) - return train_dataloader, input_file - - -class pretraining_dataset(Dataset): # noqa: N801 - def __init__(self, input_file, max_pred_length): - logger.info("pretraining_dataset: %s, max_pred_length: %d", input_file, max_pred_length) - self.input_file = input_file - self.max_pred_length = max_pred_length - f = h5py.File(input_file, "r") - keys = [ - "input_ids", - "input_mask", - "segment_ids", - "masked_lm_positions", - "masked_lm_ids", - "next_sentence_labels", - ] - self.inputs = [np.asarray(f[key][:]) for key in keys] - 
f.close() - - def __len__(self): - "Denotes the total number of samples" - return len(self.inputs[0]) - - def __getitem__(self, index): - [input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) - if indice < 5 - else torch.from_numpy(np.asarray(input[index].astype(np.int64))) - for indice, input in enumerate(self.inputs) - ] - - # HF model use default ignore_index value (-100) for CrossEntropyLoss - masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -100 - index = self.max_pred_length - # store number of masked tokens in index - padded_mask_indices = (masked_lm_positions == 0).nonzero() - if len(padded_mask_indices) != 0: - index = padded_mask_indices[0].item() - masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] - return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels] - - -import argparse # noqa: E402 - - -def parse_arguments(): - parser = argparse.ArgumentParser() - - # batch size test config parameters - parser.add_argument( - "--enable_mixed_precision", - default=False, - action="store_true", - help="Whether to use 16-bit float precision instead of 32-bit", - ) - - parser.add_argument( - "--sequence_length", - default=512, - type=int, - help="The maximum total input sequence length after WordPiece tokenization. 
\n" - "Sequences longer than this will be truncated, and sequences shorter \n" - "than this will be padded.", - ) - parser.add_argument( - "--max_predictions_per_seq", default=80, type=int, help="The maximum total of masked tokens in input sequence" - ) - parser.add_argument("--max_batch_size", default=32, type=int, help="Total batch size for training.") - - parser.add_argument("--gelu_recompute", default=False, action="store_true") - - parser.add_argument("--attn_dropout_recompute", default=False, action="store_true") - - parser.add_argument("--transformer_layer_recompute", default=False, action="store_true") - - args = parser.parse_args() - return args - - -@dataclass -class PretrainArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - """ - - input_dir: str = field( - default=None, metadata={"help": "The input data dir. Should contain .hdf5 files for the task"} - ) - - bert_model: str = field( - default=None, - metadata={ - "help": "Bert pre-trained model selected in the list: bert-base-uncased, \ - bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." - }, - ) - - output_dir: str = field( - default=None, metadata={"help": "The output directory where the model checkpoints will be written."} - ) - - cache_dir: str = field( - default="/tmp/bert_pretrain/", - metadata={"help": "The output directory where the model checkpoints will be written."}, - ) - max_seq_length: Optional[int] = field( - default=512, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer \ - than this will be truncated, sequences shorter will be padded." 
- }, - ) - - max_predictions_per_seq: Optional[int] = field( - default=80, metadata={"help": "The maximum total of masked tokens in input sequence."} - ) - - train_batch_size: Optional[int] = field(default=32, metadata={"help": "Batch size for training."}) - - learning_rate: Optional[float] = field(default=5e-5, metadata={"help": "The initial learning rate for Lamb."}) - - num_train_epochs: Optional[float] = field( - default=3.0, metadata={"help": "Total number of training epochs to perform."} - ) - - max_steps: Optional[float] = field(default=1000, metadata={"help": "Total number of training steps to perform."}) - - warmup_proportion: Optional[float] = field( - default=0.01, - metadata={ - "help": "Proportion of training to perform linear learning rate warmup for. \ - E.g., 0.1 = 10%% of training." - }, - ) - - local_rank: Optional[int] = field(default=-1, metadata={"help": "local_rank for distributed training on gpus."}) - - world_rank: Optional[int] = field(default=-1) - - world_size: Optional[int] = field(default=1) - - seed: Optional[int] = field(default=42, metadata={"help": "random seed for initialization."}) - - gradient_accumulation_steps: Optional[int] = field( - default=1, metadata={"help": "Number of updates steps to accumualte before performing a backward/update pass."} - ) - - fp16: bool = field(default=False, metadata={"help": "Whether to use 16-bit float precision instead of 32-bit."}) - - gelu_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing Gelu activation output to save memory."} - ) - attn_dropout_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing attention dropout to save memory."} - ) - transformer_layer_recompute: bool = field( - default=False, metadata={"help": "Whether to enable recomputing transformer layerwise to save memory."} - ) - - loss_scale: Optional[float] = field( - default=0.0, metadata={"help": "Loss scaling, positive power of 2 values can improve 
fp16 convergence."} - ) - - deepspeed_zero_stage: Optional[int] = field(default=0, metadata={"help": "Deepspeed Zero Stage. 0 => disabled"}) - - log_freq: Optional[float] = field(default=1.0, metadata={"help": "frequency of logging loss."}) - - checkpoint_activations: bool = field(default=False, metadata={"help": "Whether to use gradient checkpointing."}) - - resume_from_checkpoint: bool = field( - default=False, metadata={"help": "Whether to resume training from checkpoint."} - ) - - resume_step: Optional[int] = field(default=-1, metadata={"help": "Step to resume training from."}) - - num_steps_per_checkpoint: Optional[int] = field( - default=100, metadata={"help": "Number of update steps until a model checkpoint is saved to disk."} - ) - - save_checkpoint: Optional[bool] = field( - default=False, metadata={"help": "Enable for saving a model checkpoint to disk."} - ) - - init_state_dict: Optional[dict] = field(default=None, metadata={"help": "State to load before training."}) - - phase2: bool = field(default=False, metadata={"help": "Whether to train with seq len 512."}) - - allreduce_post_accumulation: bool = field( - default=False, metadata={"help": "Whether to do allreduces during gradient accumulation steps."} - ) - - allreduce_post_accumulation_fp16: bool = field( - default=False, metadata={"help": "Whether to do fp16 allreduce post accumulation."} - ) - - accumulate_into_fp16: bool = field(default=False, metadata={"help": "Whether to use fp16 gradient accumulators."}) - - phase1_end_step: Optional[int] = field( - default=7038, metadata={"help": "Whether to use fp16 gradient accumulators."} - ) - - tensorboard_dir: Optional[str] = field( - default=None, - ) - - schedule: Optional[str] = field( - default="warmup_poly", - ) - - # this argument is test specific. to run a full bert model will take too long to run. instead, we reduce - # number of hidden layers so that it can show convergence to an extend to help detect any regression. 
- force_num_hidden_layers: Optional[int] = field( - default=None, metadata={"help": "Whether to use fp16 gradient accumulators."} - ) - - def to_json_string(self): - """ - Serializes this instance to a JSON string. - """ - return json.dumps(dataclasses.asdict(self), indent=2) - - def to_sanitized_dict(self) -> Dict[str, Any]: - """ - Sanitized serialization to use with TensorBoard`s hparams - """ - d = dataclasses.asdict(self) - valid_types = [bool, int, float, str, torch.Tensor] - return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} - - -def setup_training(args): - assert torch.cuda.is_available() - - if args.local_rank == -1: - args.local_rank = 0 - args.world_rank = 0 - - print("args.local_rank: ", args.local_rank) - torch.cuda.set_device(args.local_rank) - device = torch.device("cuda", args.local_rank) - args.n_gpu = 1 - - if args.gradient_accumulation_steps < 1: - raise ValueError( - f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1" - ) - if args.train_batch_size % args.gradient_accumulation_steps != 0: - raise ValueError( - "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format( - args.gradient_accumulation_steps, args.train_batch_size - ) - ) - - # args.train_batch_size is per global step (optimization step) batch size - # now make it a per gpu batch size - args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps - args.train_batch_size = args.train_batch_size // args.world_size - - logger.info("setup_training: args.train_batch_size = %d", args.train_batch_size) - return device, args - - -def setup_torch_distributed(world_rank, world_size): - os.environ["RANK"] = str(world_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12345" - torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=world_rank) - return - - -def 
prepare_model(args, device): - config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir) - - # config.num_hidden_layers = 12 - if args.force_num_hidden_layers: - logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers) - config.num_hidden_layers = args.force_num_hidden_layers - - model = BertForPreTraining(config) - if args.init_state_dict is not None: - model.load_state_dict(args.init_state_dict) - model_desc = bert_model_description(config) - - lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion) - - loss_scaler = amp.DynamicLossScaler() if args.fp16 else None - - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": args.gradient_accumulation_steps}, - "device": {"id": str(device)}, - "mixed_precision": {"enabled": args.fp16, "loss_scaler": loss_scaler}, - "graph_transformer": { - "attn_dropout_recompute": args.attn_dropout_recompute, - "gelu_recompute": args.gelu_recompute, - "transformer_layer_recompute": args.transformer_layer_recompute, - }, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": True}, - "distributed": { - "world_rank": max(0, args.local_rank), - "world_size": args.world_size, - "local_rank": max(0, args.local_rank), - "allreduce_post_accumulation": args.allreduce_post_accumulation, - "deepspeed_zero_optimization": {"stage": args.deepspeed_zero_stage}, - "enable_adasum": False, - }, - "lr_scheduler": lr_scheduler, - } - ) - - param_optimizer = list(model.named_parameters()) - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - params = [ - { - "params": [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - { - "params": [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)], - "alpha": 0.9, - "beta": 0.999, - 
"lambda": 0.0, - "epsilon": 1e-6, - }, - ] - - optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) - model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options) - - return model - - -def get_data_file(f_id, world_rank, world_size, files): - num_files = len(files) - if world_size > num_files: - remainder = world_size % num_files - return files[(f_id * world_size + world_rank + remainder * f_id) % num_files] - elif world_size > 1: - return files[(f_id * world_size + world_rank) % num_files] - else: - return files[f_id % num_files] - - -def main(): - parser = HfArgumentParser(PretrainArguments) - args = parser.parse_args_into_dataclasses()[0] - do_pretrain(args) - - -def do_pretrain(args): - if is_main_process(args) and args.tensorboard_dir: - tb_writer = SummaryWriter(log_dir=args.tensorboard_dir) - tb_writer.add_text("args", args.to_json_string()) - tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={}) - else: - tb_writer = None - - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - ort.set_seed(args.seed) - - device, args = setup_training(args) - - model = prepare_model(args, device) - - logger.info("Running training: Batch size = %d, initial LR = %f", args.train_batch_size, args.learning_rate) - - average_loss = 0.0 - epoch = 0 - training_steps = 0 - - pool = ProcessPoolExecutor(1) - while True: - files = [ - os.path.join(args.input_dir, f) - for f in os.listdir(args.input_dir) - if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in f - ] - files.sort() - random.shuffle(files) - - f_id = 0 - train_dataloader, data_file = create_pretraining_dataset( - get_data_file(f_id, args.world_rank, args.world_size, files), args.max_predictions_per_seq, args - ) - - for f_id in range(1, len(files)): - logger.info("data file %s" % (data_file)) - - dataset_future = pool.submit( - create_pretraining_dataset, - get_data_file(f_id, args.world_rank, args.world_size, files), - 
args.max_predictions_per_seq, - args, - ) - - train_iter = tqdm(train_dataloader, desc="Iteration") if is_main_process(args) else train_dataloader - for _step, batch in enumerate(train_iter): - training_steps += 1 - batch = [t.to(device) for t in batch] # noqa: PLW2901 - input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch - - loss, _, _ = model.train_step( - input_ids, input_mask, segment_ids, masked_lm_labels, next_sentence_labels - ) - average_loss += loss.item() - - global_step = model._train_step_info.optimization_step - if training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0: - if is_main_process(args): - divisor = args.log_freq * args.gradient_accumulation_steps - if tb_writer: - lr = model.options.lr_scheduler.get_last_lr()[0] - tb_writer.add_scalar("train/summary/scalar/Learning_Rate", lr, global_step) - if args.fp16: - tb_writer.add_scalar("train/summary/scalar/loss_scale_25", loss, global_step) - # TODO: ORTTrainer to expose all_finite - # tb_writer.add_scalar('train/summary/scalar/all_fp16_gradients_finite_859', all_finite, global_step) - tb_writer.add_scalar("train/summary/total_loss", average_loss / divisor, global_step) - - print(f"Step:{global_step} Average Loss = {average_loss / divisor}") - - if global_step >= args.max_steps or global_step >= force_to_stop_max_steps: - if tb_writer: - tb_writer.close() - - if global_step >= args.max_steps: - if args.save_checkpoint: - model.save_checkpoint(os.path.join(args.output_dir, f"checkpoint-{args.world_rank}.ortcp")) - final_loss = average_loss / (args.log_freq * args.gradient_accumulation_steps) - return final_loss - - average_loss = 0 - - del train_dataloader - - train_dataloader, data_file = dataset_future.result(timeout=None) - - epoch += 1 - - -def generate_tensorboard_logdir(root_dir): - current_date_time = datetime.datetime.today() - - dt_string = current_date_time.strftime("BERT_pretrain_%y_%m_%d_%I_%M_%S") - return os.path.join(root_dir, dt_string) 
- - -class ORTBertPretrainTest(unittest.TestCase): - def setUp(self): - self.output_dir = "/bert_data/hf_data/test_out/bert_pretrain_results" - self.bert_model = "bert-base-uncased" - self.local_rank = -1 - self.world_rank = -1 - self.world_size = 1 - self.max_steps = 300000 - self.learning_rate = 5e-4 - self.max_seq_length = 512 - self.max_predictions_per_seq = 20 - self.input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - self.train_batch_size = 4096 - self.gradient_accumulation_steps = 64 - self.fp16 = True - self.allreduce_post_accumulation = True - self.tensorboard_dir = "/bert_data/hf_data/test_out" - - def test_pretrain_throughput(self, process_args=None): - if process_args.sequence_length == 128: - input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - else: - input_dir = "/bert_data/hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" - - print("process_args.enable_mixed_precision: ", process_args.enable_mixed_precision) - print("process_args.sequence_length: ", process_args.sequence_length) - print("process_args.max_batch_size: ", process_args.max_batch_size) - print("process_args.max_predictions_per_seq: ", process_args.max_predictions_per_seq) - print("process_args.gelu_recompute: ", process_args.gelu_recompute) - print("process_args.attn_dropout_recompute: ", process_args.attn_dropout_recompute) - print("process_args.transformer_layer_recompute: ", process_args.transformer_layer_recompute) - - args = PretrainArguments( - input_dir=input_dir, - output_dir="/bert_data/hf_data/test_out/bert_pretrain_results", - bert_model="bert-large-uncased", - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=10, - learning_rate=5e-4, - 
max_seq_length=process_args.sequence_length, - max_predictions_per_seq=process_args.max_predictions_per_seq, - train_batch_size=process_args.max_batch_size, - gradient_accumulation_steps=1, - fp16=process_args.enable_mixed_precision, - gelu_recompute=process_args.gelu_recompute, - attn_dropout_recompute=process_args.attn_dropout_recompute, - transformer_layer_recompute=process_args.transformer_layer_recompute, - allreduce_post_accumulation=True, - # TODO: remove - force_num_hidden_layers=2, - ) - do_pretrain(args) - - def test_pretrain_convergence(self): - args = PretrainArguments( - output_dir=self.output_dir, - bert_model=self.bert_model, - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=self.max_steps, - learning_rate=self.learning_rate, - max_seq_length=self.max_seq_length, - max_predictions_per_seq=self.max_predictions_per_seq, - train_batch_size=self.train_batch_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - input_dir=self.input_dir, - fp16=self.fp16, - allreduce_post_accumulation=self.allreduce_post_accumulation, - force_num_hidden_layers=self.force_num_hidden_layers, - tensorboard_dir=generate_tensorboard_logdir("/bert_data/hf_data/test_out/"), - ) - final_loss = do_pretrain(args) - return final_loss - - def test_pretrain_zero(self): - assert self.world_size > 0, "ZeRO test requires a distributed run." 
- setup_torch_distributed(self.world_rank, self.world_size) - per_gpu_batch_size = 32 - optimization_batch_size = per_gpu_batch_size * self.world_size # set to disable grad accumulation - - self.train_batch_size = optimization_batch_size - self.gradient_accumulation_steps = 1 - self.deepspeed_zero_stage = 1 - self.force_num_hidden_layers = 2 - self.max_seq_length = 32 - self.output_dir = "./bert_pretrain_ckpt" - if self.world_rank == 0: - if os.path.isdir(self.output_dir): - shutil.rmtree(self.output_dir) - os.makedirs(self.output_dir, exist_ok=True) - - torch.distributed.barrier() - - assert os.path.exists(self.output_dir) - - # run a few optimization steps - self.max_steps = 200 - args = PretrainArguments( - output_dir=self.output_dir, - bert_model=self.bert_model, - local_rank=self.local_rank, - world_rank=self.world_rank, - world_size=self.world_size, - max_steps=self.max_steps, - learning_rate=self.learning_rate, - max_seq_length=self.max_seq_length, - max_predictions_per_seq=self.max_predictions_per_seq, - train_batch_size=self.train_batch_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - input_dir=self.input_dir, - fp16=self.fp16, - allreduce_post_accumulation=self.allreduce_post_accumulation, - force_num_hidden_layers=self.force_num_hidden_layers, - deepspeed_zero_stage=self.deepspeed_zero_stage, - save_checkpoint=True, - ) - do_pretrain(args) - - # ensure all workers reach this point before loading the checkpointed state - torch.distributed.barrier() - - # on rank 0, load the trained state - if args.world_rank == 0: - checkpoint_files = glob.glob(os.path.join(self.output_dir, "checkpoint*.ortcp")) - args.init_state_dict = aggregate_checkpoints(checkpoint_files, pytorch_format=True) - - torch.distributed.barrier() - - # run a single step to get the loss, on rank 0 should be lesser than starting loss - args.save_checkpoint = False - args.max_steps = 1 - args.deepspeed_zero_stage = 0 - final_loss = do_pretrain(args) - return final_loss - 
- -if __name__ == "__main__": - import sys - - logger.warning("sys.argv: %s", sys.argv) - # usage: - # data parallel training - # mpirun -n 4 python orttraining_run_bert_pretrain.py - # - # single gpu: - # python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput - # [batch size test arguments] - # python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_convergence - # - # pytorch.distributed.launch will not work because ort backend requires MPI to broadcast ncclUniqueId - # calling unpublished get_mpi_context_xxx to get rank/size numbers. - try: - # In case ORT is not built with MPI/NCCL, there are no get_mpi_context_xxx internal apis. - from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size - - has_get_mpi_context_internal_api = True - except ImportError: - has_get_mpi_context_internal_api = False - pass - if has_get_mpi_context_internal_api and get_mpi_context_world_size() > 1: - world_size = get_mpi_context_world_size() - print("get_mpi_context_world_size(): ", world_size) - local_rank = get_mpi_context_local_rank() - - if local_rank == 0: - print("================================================================> os.getpid() = ", os.getpid()) - - test = ORTBertPretrainTest() - test.setUp() - test.local_rank = local_rank - test.world_rank = local_rank - test.world_size = world_size - - if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_zero": - logger.info("running ORTBertPretrainTest.test_pretrain_zero()...") - final_loss = test.test_pretrain_zero() - logger.info("ORTBertPretrainTest.test_pretrain_zero() rank = %i final loss = %f", local_rank, final_loss) - if local_rank == 0: - test.assertLess(final_loss, 10.2) - else: - test.assertGreater(final_loss, 11.0) - 
logger.info("ORTBertPretrainTest.test_pretrain_zero() passed") - elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence": - logger.info("running ORTBertPretrainTest.test_pretrain_convergence()...") - test.max_steps = 200 - test.force_num_hidden_layers = 8 - final_loss = test.test_pretrain_convergence() - logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss) - test.assertLess(final_loss, 8.5) - logger.info("ORTBertPretrainTest.test_pretrain_convergence() passed") - else: - # https://microsoft.sharepoint.com/teams/ONNX2/_layouts/15/Doc.aspx?sourcedoc={170774be-e1c6-4f8b-a3ae-984f211fe410}&action=edit&wd=target%28ONNX%20Training.one%7C8176133b-c7cb-4ef2-aa9d-3fdad5344c40%2FGitHub%20Master%20Merge%20Schedule%7Cb67f0db1-e3a0-4add-80a6-621d67fd8107%2F%29 - # to make equivalent args for cpp convergence test - test.max_seq_length = 128 - test.max_predictions_per_seq = 20 - test.gradient_accumulation_steps = 16 - - # cpp_batch_size (=64) * grad_acc * world_size - test.train_batch_size = 64 * test.gradient_accumulation_steps * test.world_size - test.max_steps = 300000 - - test.force_num_hidden_layers = None - - # already using Adam (e.g. 
AdamConfig) - test.learning_rate = 5e-4 - test.warmup_proportion = 0.1 - - final_loss = test.test_pretrain_convergence() - logger.info("ORTBertPretrainTest.test_pretrain_convergence() final loss = %f", final_loss) - else: - # unittest does not accept user defined arguments - # we need to run this script with user defined arguments - if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_throughput": - run_test_pretrain_throughput, run_test_pretrain_convergence = True, False - sys.argv.remove("ORTBertPretrainTest.test_pretrain_throughput") - elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence": - run_test_pretrain_throughput, run_test_pretrain_convergence = False, True - sys.argv.remove("ORTBertPretrainTest.test_pretrain_convergence") - else: - run_test_pretrain_throughput, run_test_pretrain_convergence = True, True - process_args = parse_arguments() - test = ORTBertPretrainTest() - test.setUp() - - if run_test_pretrain_throughput: - logger.info("running single GPU ORTBertPretrainTest.test_pretrain_throughput()...") - test.test_pretrain_throughput(process_args) - logger.info("single GPU ORTBertPretrainTest.test_pretrain_throughput() passed") - - # unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py b/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py deleted file mode 100644 index 3e2d1a7154bfd..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py +++ /dev/null @@ -1,67 +0,0 @@ -import collections -import subprocess -import sys - -Config = collections.namedtuple( - "Config", - [ - "enable_mixed_precision", - "sequence_length", - "max_batch_size", - "max_predictions_per_seq", - "gelu_recompute", - "attn_dropout_recompute", - "transformer_layer_recompute", - ], -) - -configs = [ - Config(True, 128, 46, 20, False, False, False), - Config(True, 512, 8, 80, False, False, False), - 
Config(False, 128, 26, 20, False, False, False), - Config(False, 512, 4, 80, False, False, False), - Config(True, 128, 50, 20, True, False, False), - Config(True, 128, 50, 20, False, True, False), - Config(True, 128, 76, 20, False, False, True), - Config(True, 512, 8, 80, True, False, False), - Config(True, 512, 9, 80, False, True, False), - Config(True, 512, 15, 80, False, False, True), -] - - -def run_with_config(config): - print( - "##### testing name - {}-{} #####".format( - "fp16" if config.enable_mixed_precision else "fp32", config.sequence_length - ) - ) - print("gelu_recompute: ", config.gelu_recompute) - print("attn_dropout_recompute: ", config.attn_dropout_recompute) - print("transformer_layer_recompute: ", config.transformer_layer_recompute) - - cmds = [ - sys.executable, - "orttraining_run_bert_pretrain.py", - "ORTBertPretrainTest.test_pretrain_throughput", - "--sequence_length", - str(config.sequence_length), - "--max_batch_size", - str(config.max_batch_size), - "--max_predictions_per_seq", - str(config.max_predictions_per_seq), - ] - if config.enable_mixed_precision: - cmds.append("--enable_mixed_precision") - if config.gelu_recompute: - cmds.append("--gelu_recompute") - if config.attn_dropout_recompute: - cmds.append("--attn_dropout_recompute") - if config.transformer_layer_recompute: - cmds.append("--transformer_layer_recompute") - - # access to azure storage shared disk is much slower so we need a longer timeout. 
- subprocess.run(cmds, timeout=1200).check_returncode() # noqa: PLW1510 - - -for config in configs: - run_with_config(config) diff --git a/orttraining/orttraining/test/python/orttraining_run_glue.py b/orttraining/orttraining/test/python/orttraining_run_glue.py deleted file mode 100644 index 794e2f8cc7240..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_glue.py +++ /dev/null @@ -1,323 +0,0 @@ -# adapted from run_glue.py of huggingface transformers - -import dataclasses # noqa: F401 -import logging -import os -import unittest -from dataclasses import dataclass, field -from typing import Dict, Optional - -import numpy as np -from numpy.testing import assert_allclose -from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - AutoTokenizer, - EvalPrediction, - GlueDataset, - GlueDataTrainingArguments, - TrainingArguments, - glue_compute_metrics, - glue_output_modes, - glue_tasks_num_labels, - set_seed, -) - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - -try: - from onnxruntime.capi._pybind_state import get_mpi_context_local_size # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_world_rank # noqa: F401 - from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_world_size - - has_get_mpi_context_internal_api = True -except ImportError: - has_get_mpi_context_internal_api = False - pass - - -import torch # noqa: F401 -from orttraining_transformer_trainer import ORTTransformerTrainer - -logger = logging.getLogger(__name__) - - -def verify_old_and_new_api_are_equal(results_per_api): - new_api_results = results_per_api[True] - old_api_results = results_per_api[False] - for key in new_api_results: - assert_allclose(new_api_results[key], old_api_results[key]) - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
- """ - - model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"}) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} - ) - - -class ORTGlueTest(unittest.TestCase): - def setUp(self): - # configurations not to be changed accoss tests - self.max_seq_length = 128 - self.train_batch_size = 8 - self.learning_rate = 2e-5 - self.num_train_epochs = 3.0 - self.local_rank = -1 - self.world_size = 1 - self.overwrite_output_dir = True - self.gradient_accumulation_steps = 1 - self.data_dir = "/bert_data/hf_data/glue_data/" - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "glue_test_output/") - self.cache_dir = "/tmp/glue/" - self.logging_steps = 10 - - def test_roberta_with_mrpc(self): - expected_acc = 0.85 - expected_f1 = 0.88 - expected_loss = 0.35 - results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=False) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_roberta_fp16_with_mrpc(self): - expected_acc = 0.87 - expected_f1 = 0.90 - expected_loss = 0.33 - - results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=True) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_bert_with_mrpc(self): - if self.local_rank == -1: - expected_acc = 0.83 - expected_f1 = 0.88 - expected_loss = 0.44 - elif self.local_rank == 0: - expected_acc = 0.81 - expected_f1 = 0.86 - expected_loss = 0.44 - - results = self.run_glue(model_name="bert-base-cased", 
task_name="MRPC", fp16=False) - - if self.local_rank in [-1, 0]: - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def test_bert_fp16_with_mrpc(self): - expected_acc = 0.84 - expected_f1 = 0.88 - expected_loss = 0.46 - - results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=True) - - assert results["acc"] >= expected_acc - assert results["f1"] >= expected_f1 - assert results["loss"] <= expected_loss - - def model_to_desc(self, model_name, model): - if model_name.startswith("bert") or model_name.startswith("xlnet"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "labels", - [ - "batch", - ], - ), - ], - "outputs": [("loss", [], True), ("logits", ["batch", 2])], - } - elif model_name.startswith("roberta"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "labels", - [ - "batch", - ], - ), - ], - "outputs": [("loss", [], True), ("logits", ["batch", 2])], - } - else: - raise RuntimeError(f"unsupported base model name {model_name}.") - - return model_desc - - def run_glue(self, model_name, task_name, fp16): - model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) - data_args = GlueDataTrainingArguments( - task_name=task_name, data_dir=os.path.join(self.data_dir, task_name), max_seq_length=self.max_seq_length - ) - - training_args = TrainingArguments( - output_dir=os.path.join(self.output_dir, task_name), - do_train=True, - do_eval=True, - per_gpu_train_batch_size=self.train_batch_size, - learning_rate=self.learning_rate, - num_train_epochs=self.num_train_epochs, - local_rank=self.local_rank, - overwrite_output_dir=self.overwrite_output_dir, - 
gradient_accumulation_steps=self.gradient_accumulation_steps, - fp16=fp16, - logging_steps=self.logging_steps, - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, - ) - logger.warning( - "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - training_args.local_rank, - training_args.device, - training_args.n_gpu, - bool(training_args.local_rank != -1), - training_args.fp16, - ) - logger.info("Training/evaluation parameters %s", training_args) - - set_seed(training_args.seed) - onnxruntime.set_seed(training_args.seed) - - try: - num_labels = glue_tasks_num_labels[data_args.task_name] - output_mode = glue_output_modes[data_args.task_name] - except KeyError: - raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904 - - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task=data_args.task_name, - cache_dir=model_args.cache_dir, - ) - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - ) - - model = AutoModelForSequenceClassification.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - ) - - train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None - - eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None - - def compute_metrics(p: EvalPrediction) -> Dict: - if output_mode == "classification": - preds = np.argmax(p.predictions, axis=1) - elif output_mode == "regression": - preds = np.squeeze(p.predictions) - return 
glue_compute_metrics(data_args.task_name, preds, p.label_ids) - - model_desc = self.model_to_desc(model_name, model) - # Initialize the ORTTrainer within ORTTransformerTrainer - trainer = ORTTransformerTrainer( - model=model, - model_desc=model_desc, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - compute_metrics=compute_metrics, - world_size=self.world_size, - ) - - # Training - if training_args.do_train: - trainer.train() - trainer.save_model() - - # Evaluation - results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: - logger.info("*** Evaluate ***") - - result = trainer.evaluate() - - logger.info(f"***** Eval results {data_args.task_name} *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - - results.update(result) - - return results - - -if __name__ == "__main__": - if has_get_mpi_context_internal_api: - local_rank = get_mpi_context_local_rank() - world_size = get_mpi_context_world_size() - else: - local_rank = -1 - world_size = 1 - - if world_size > 1: - # mpi launch - logger.warning("mpirun launch, local_rank / world_size: %s : % s", local_rank, world_size) - - # TrainingArguments._setup_devices will call torch.distributed.init_process_group(backend="nccl") - # pytorch expects following environment settings (which would be set if launched with torch.distributed.launch). 
- - os.environ["RANK"] = str(local_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "29500" - - from onnxruntime.capi._pybind_state import set_cuda_device_id - - set_cuda_device_id(local_rank) - - test = ORTGlueTest() - test.setUp() - test.local_rank = local_rank - test.world_size = world_size - test.test_bert_with_mrpc() - else: - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py b/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py deleted file mode 100644 index 92db204593bcd..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py +++ /dev/null @@ -1,281 +0,0 @@ -# adapted from run_multiple_choice.py of huggingface transformers -# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/run_multiple_choice.py - -import dataclasses # noqa: F401 -import logging -import os -import unittest -from dataclasses import dataclass, field -from typing import Dict, Optional - -import numpy as np -import torch # noqa: F401 -from numpy.testing import assert_allclose # noqa: F401 -from orttraining_run_glue import verify_old_and_new_api_are_equal # noqa: F401 -from orttraining_transformer_trainer import ORTTransformerTrainer -from transformers import HfArgumentParser # noqa: F401 -from transformers import Trainer # noqa: F401 -from transformers import ( - AutoConfig, - AutoModelForMultipleChoice, - AutoTokenizer, - EvalPrediction, - TrainingArguments, - set_seed, -) -from utils_multiple_choice import MultipleChoiceDataset, Split, SwagProcessor - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - -logger = logging.getLogger(__name__) - - -def simple_accuracy(preds, labels): - return (preds == labels).mean() - - -@dataclass -class ModelArguments: - """ - Arguments pertaining to which 
model/config/tokenizer we are going to fine-tune from. - """ - - model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"}) - config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} - ) - tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} - ) - cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} - ) - - -@dataclass -class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ - - task_name: str = field(metadata={"help": "The name of the task to train on."}) - data_dir: str = field(metadata={"help": "Should contain the data files for the task."}) - max_seq_length: int = field( - default=128, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." 
- }, - ) - overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}) - - -class ORTMultipleChoiceTest(unittest.TestCase): - def setUp(self): - # configurations not to be changed accoss tests - self.max_seq_length = 80 - self.train_batch_size = 16 - self.eval_batch_size = 2 - self.learning_rate = 2e-5 - self.num_train_epochs = 1.0 - self.local_rank = -1 - self.overwrite_output_dir = True - self.gradient_accumulation_steps = 8 - self.data_dir = "/bert_data/hf_data/swag/swagaf/data" - self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "multiple_choice_test_output/") - self.cache_dir = "/tmp/multiple_choice/" - self.logging_steps = 10 - self.rtol = 2e-01 - - def test_bert_with_swag(self): - expected_acc = 0.75 - expected_loss = 0.64 - - results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=False) - assert results["acc"] >= expected_acc - assert results["loss"] <= expected_loss - - def test_bert_fp16_with_swag(self): - # larger batch can be handled with mixed precision - self.train_batch_size = 32 - - expected_acc = 0.73 - expected_loss = 0.68 - - results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=True) - assert results["acc"] >= expected_acc - assert results["loss"] <= expected_loss - - def run_multiple_choice(self, model_name, task_name, fp16): - model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) - data_args = DataTrainingArguments( - task_name=task_name, data_dir=self.data_dir, max_seq_length=self.max_seq_length - ) - - training_args = TrainingArguments( - output_dir=os.path.join(self.output_dir, task_name), - do_train=True, - do_eval=True, - per_gpu_train_batch_size=self.train_batch_size, - per_gpu_eval_batch_size=self.eval_batch_size, - learning_rate=self.learning_rate, - num_train_epochs=self.num_train_epochs, - local_rank=self.local_rank, - 
overwrite_output_dir=self.overwrite_output_dir, - gradient_accumulation_steps=self.gradient_accumulation_steps, - fp16=fp16, - logging_steps=self.logging_steps, - ) - - # Setup logging - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, - ) - logger.warning( - "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", - training_args.local_rank, - training_args.device, - training_args.n_gpu, - bool(training_args.local_rank != -1), - training_args.fp16, - ) - logger.info("Training/evaluation parameters %s", training_args) - - set_seed(training_args.seed) - onnxruntime.set_seed(training_args.seed) - - try: - processor = SwagProcessor() - label_list = processor.get_labels() - num_labels = len(label_list) - except KeyError: - raise ValueError("Task not found: %s" % (data_args.task_name)) # noqa: B904 - - config = AutoConfig.from_pretrained( - model_args.config_name if model_args.config_name else model_args.model_name_or_path, - num_labels=num_labels, - finetuning_task=data_args.task_name, - cache_dir=model_args.cache_dir, - ) - - tokenizer = AutoTokenizer.from_pretrained( - model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, - cache_dir=model_args.cache_dir, - ) - - model = AutoModelForMultipleChoice.from_pretrained( - model_args.model_name_or_path, - from_tf=bool(".ckpt" in model_args.model_name_or_path), - config=config, - cache_dir=model_args.cache_dir, - ) - - # Get datasets - train_dataset = ( - MultipleChoiceDataset( - data_dir=data_args.data_dir, - tokenizer=tokenizer, - task=data_args.task_name, - processor=processor, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.train, - ) - if training_args.do_train - else None - ) - eval_dataset = ( - MultipleChoiceDataset( - data_dir=data_args.data_dir, - 
tokenizer=tokenizer, - task=data_args.task_name, - processor=processor, - max_seq_length=data_args.max_seq_length, - overwrite_cache=data_args.overwrite_cache, - mode=Split.dev, - ) - if training_args.do_eval - else None - ) - - def compute_metrics(p: EvalPrediction) -> Dict: - preds = np.argmax(p.predictions, axis=1) - return {"acc": simple_accuracy(preds, p.label_ids)} - - if model_name.startswith("bert"): - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "labels", - ["batch", num_labels], - ), - ], - "outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])], - } - else: - model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", num_labels, "max_seq_len_in_batch"], - ), - ( - "labels", - ["batch", num_labels], - ), - ], - "outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])], - } - - # Initialize the ORTTrainer within ORTTransformerTrainer - trainer = ORTTransformerTrainer( - model=model, - model_desc=model_desc, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - compute_metrics=compute_metrics, - ) - - # Training - if training_args.do_train: - trainer.train() - trainer.save_model() - - # Evaluation - results = {} - if training_args.do_eval and training_args.local_rank in [-1, 0]: - logger.info("*** Evaluate ***") - - result = trainer.evaluate() - - logger.info(f"***** Eval results {data_args.task_name} *****") - for key, value in result.items(): - logger.info(" %s = %s", key, value) - - results.update(result) - - return results - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py 
b/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py deleted file mode 100644 index 71e6bb8e4d2f2..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py +++ /dev/null @@ -1,6 +0,0 @@ -from orttraining_test_layer_norm_transform import layer_norm_transform # noqa: F401 -from orttraining_test_model_transform import add_expand_shape, add_name, fix_transpose # noqa: F401 - - -def postprocess_model(model): - add_name(model) diff --git a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py b/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py deleted file mode 100644 index 21372caaf6779..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# orttraining_test_checkpoint_storage.py - -import os -import pickle -import shutil - -import numpy as np -import pytest -import torch - -from onnxruntime.training import _checkpoint_storage - -# Helper functions - - -def _equals(a, b): - """Checks recursively if two dictionaries are equal""" - if isinstance(a, dict): - return all(not (key not in b or not _equals(a[key], b[key])) for key in a) - else: - if isinstance(a, bytes): - a = a.decode() - if isinstance(b, bytes): - b = b.decode() - are_equal = a == b - return are_equal if isinstance(are_equal, bool) else are_equal.all() - - return False - - -def _numpy_types(obj_value): - """Return a bool indicating whether or not the input obj_value is a numpy type object - - Recursively checks if the obj_value (could be a dictionary) is a numpy type object. - Exceptions are str and bytes. 
- - Returns true if object is numpy type, str, or bytes - False if any other type - """ - if not isinstance(obj_value, dict): - return isinstance(obj_value, (str, bytes)) or type(obj_value).__module__ == np.__name__ - - return all(_numpy_types(value) for _, value in obj_value.items()) - - -def _get_dict(separated_key): - """Create dummy dictionary with different datatypes - - Returns the tuple of the entire dummy dictionary created, key argument as a dictionary for _checkpoint_storage.load - function and the value for that key in the original dictionary - - For example the complete dictionary is represented by: - { - 'int1':1, - 'int2': 2, - 'int_list': [1,2,3,5,6], - 'dict1': { - 'np_array': np.arange(100), - 'dict2': {'int3': 3, 'int4': 4}, - 'str1': "onnxruntime" - }, - 'bool1': bool(True), - 'int5': 5, - 'float1': 2.345, - 'np_array_float': np.array([1.234, 2.345, 3.456]), - 'np_array_float_3_dim': np.array([[[1,2],[3,4]], [[5,6],[7,8]]]) - } - - if the input key is ['dict1', 'str1'], then the key argument returned is 'dict1/str1' - and the value corresponding to that is "onnxruntime" - - so, for the above example, the returned tuple is: - (original_dict, {'key': 'dict1/str1', "onnxruntime") - """ - test_dict = { - "int1": 1, - "int2": 2, - "int_list": [1, 2, 3, 5, 6], - "dict1": {"np_array": np.arange(100), "dict2": {"int3": 3, "int4": 4}, "str1": "onnxruntime"}, - "bool1": True, - "int5": 5, - "float1": 2.345, - "np_array_float": np.array([1.234, 2.345, 3.456]), - "np_array_float_3_dim": np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), - } - key = "" - expected_val = test_dict - for single_key in separated_key: - key += single_key + "/" - expected_val = expected_val[single_key] - return test_dict, {"key": key} if len(separated_key) > 0 else dict(), expected_val - - -class _CustomClass: - """Custom object that encpsulates dummy values for loss, epoch and train_step""" - - def __init__(self): - self._loss = 1.23 - self._epoch = 12000 - self._train_step = 25 - - 
def __eq__(self, other): - if isinstance(other, _CustomClass): - return self._loss == other._loss and self._epoch == other._epoch and self._train_step == other._train_step - - -# Test fixtures - - -@pytest.yield_fixture(scope="function") -def checkpoint_storage_test_setup(): - checkpoint_dir = os.path.abspath("checkpoint_dir/") - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir, exist_ok=True) - pytest.checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ortcp") - yield "checkpoint_storage_test_setup" - shutil.rmtree(checkpoint_dir) - - -@pytest.yield_fixture(scope="function") -def checkpoint_storage_test_parameterized_setup(request, checkpoint_storage_test_setup): - yield request.param - - -# Tests - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - _get_dict([]), - _get_dict(["int1"]), - _get_dict(["dict1"]), - _get_dict(["dict1", "dict2"]), - _get_dict(["dict1", "dict2", "int4"]), - _get_dict(["dict1", "str1"]), - _get_dict(["bool1"]), - _get_dict(["float1"]), - _get_dict(["np_array_float"]), - ], - indirect=True, -) -def test_checkpoint_storage_saved_dict_matches_loaded(checkpoint_storage_test_parameterized_setup): - to_save = checkpoint_storage_test_parameterized_setup[0] - key_arg = checkpoint_storage_test_parameterized_setup[1] - expected = checkpoint_storage_test_parameterized_setup[2] - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - loaded = _checkpoint_storage.load(pytest.checkpoint_path, **key_arg) - assert _equals(loaded, expected) - assert _numpy_types(loaded) - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [{"int_set": {1, 2, 3, 4, 5}}, {"str_set": {"one", "two"}}, [1, 2, 3], 2.352], - indirect=True, -) -def test_checkpoint_storage_saving_non_supported_types_fails(checkpoint_storage_test_parameterized_setup): - to_save = checkpoint_storage_test_parameterized_setup - with pytest.raises(Exception): # noqa: B017 - _checkpoint_storage.save(to_save, 
pytest.checkpoint_path) - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - ({"int64_tensor": torch.tensor(np.arange(100))}, "int64_tensor", torch.int64, np.int64), - ({"int32_tensor": torch.tensor(np.arange(100), dtype=torch.int32)}, "int32_tensor", torch.int32, np.int32), - ({"int16_tensor": torch.tensor(np.arange(100), dtype=torch.int16)}, "int16_tensor", torch.int16, np.int16), - ({"int8_tensor": torch.tensor(np.arange(100), dtype=torch.int8)}, "int8_tensor", torch.int8, np.int8), - ({"float64_tensor": torch.tensor(np.array([1.0, 2.0]))}, "float64_tensor", torch.float64, np.float64), - ( - {"float32_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32)}, - "float32_tensor", - torch.float32, - np.float32, - ), - ( - {"float16_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float16)}, - "float16_tensor", - torch.float16, - np.float16, - ), - ], - indirect=True, -) -def test_checkpoint_storage_saving_tensor_datatype(checkpoint_storage_test_parameterized_setup): - tensor_dict = checkpoint_storage_test_parameterized_setup[0] - tensor_name = checkpoint_storage_test_parameterized_setup[1] - tensor_dtype = checkpoint_storage_test_parameterized_setup[2] - np_dtype = checkpoint_storage_test_parameterized_setup[3] - - _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert isinstance(loaded[tensor_name], np.ndarray) - assert tensor_dict[tensor_name].dtype == tensor_dtype - assert loaded[tensor_name].dtype == np_dtype - assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all() - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", - [ - ({"two_dim": torch.ones([2, 4], dtype=torch.float64)}, "two_dim"), - ({"three_dim": torch.ones([2, 4, 6], dtype=torch.float64)}, "three_dim"), - ({"four_dim": torch.ones([2, 4, 6, 8], dtype=torch.float64)}, "four_dim"), - ], - indirect=True, -) -def 
test_checkpoint_storage_saving_multiple_dimension_tensors(checkpoint_storage_test_parameterized_setup): - tensor_dict = checkpoint_storage_test_parameterized_setup[0] - tensor_name = checkpoint_storage_test_parameterized_setup[1] - - _checkpoint_storage.save(tensor_dict, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert isinstance(loaded[tensor_name], np.ndarray) - assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all() - - -@pytest.mark.parametrize( - "checkpoint_storage_test_parameterized_setup", [{}, {"a": {}}, {"a": {"b": {}}}], indirect=True -) -def test_checkpoint_storage_saving_and_loading_empty_dictionaries_succeeds(checkpoint_storage_test_parameterized_setup): - saved = checkpoint_storage_test_parameterized_setup - _checkpoint_storage.save(saved, pytest.checkpoint_path) - - loaded = _checkpoint_storage.load(pytest.checkpoint_path) - assert _equals(saved, loaded) - - -def test_checkpoint_storage_load_file_that_does_not_exist_fails(checkpoint_storage_test_setup): - with pytest.raises(Exception): # noqa: B017 - _checkpoint_storage.load(pytest.checkpoint_path) - - -def test_checkpoint_storage_for_custom_user_dict_succeeds(checkpoint_storage_test_setup): - custom_class = _CustomClass() - user_dict = {"tensor1": torch.tensor(np.arange(100), dtype=torch.float32), "custom_class": custom_class} - - pickled_bytes = pickle.dumps(user_dict).hex() - to_save = {"a": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32), "user_dict": pickled_bytes} - _checkpoint_storage.save(to_save, pytest.checkpoint_path) - - loaded_dict = _checkpoint_storage.load(pytest.checkpoint_path) - assert (loaded_dict["a"] == to_save["a"].numpy()).all() - try: # noqa: SIM105 - loaded_dict["user_dict"] = loaded_dict["user_dict"].decode() - except AttributeError: - pass - loaded_obj = pickle.loads(bytes.fromhex(loaded_dict["user_dict"])) - - assert torch.all(loaded_obj["tensor1"].eq(user_dict["tensor1"])) - assert 
loaded_obj["custom_class"] == custom_class diff --git a/orttraining/orttraining/test/python/orttraining_test_data_loader.py b/orttraining/orttraining/test/python/orttraining_test_data_loader.py index aa15b44ae0d66..0009d2d3d7e1b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_data_loader.py +++ b/orttraining/orttraining/test/python/orttraining_test_data_loader.py @@ -4,8 +4,6 @@ import torch from torch.utils.data import DataLoader, Dataset -from onnxruntime.capi.ort_trainer import generate_sample - global_rng = random.Random() @@ -41,6 +39,16 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() +def generate_sample(desc, device=None): + """Generate a sample based on the description""" + # symbolic dimensions are described with strings. set symbolic dimensions to be 1 + size = [s if isinstance(s, (int)) else 1 for s in desc.shape_] + if desc.num_classes_: + return torch.randint(0, desc.num_classes_, size, dtype=desc.dtype_).to(device) + else: + return torch.randn(size, dtype=desc.dtype_).to(device) + + class OrtTestDataset(Dataset): def __init__(self, input_desc, seq_len, dataset_len, device): import copy diff --git a/orttraining/orttraining/test/python/orttraining_test_debuggability.py b/orttraining/orttraining/test/python/orttraining_test_debuggability.py deleted file mode 100644 index 499f0ba7a1ff5..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_debuggability.py +++ /dev/null @@ -1,40 +0,0 @@ -import pytest -import torch -from _test_commons import _load_pytorch_transformer_model - -from onnxruntime import set_seed -from onnxruntime.training import optim, orttrainer - -############################################################################### -# Testing starts here ######################################################### -############################################################################### - - -@pytest.mark.parametrize( 
- "seed, device", - [ - (24, "cuda"), - ], -) -def testORTTransformerModelExport(seed, device): - # Common setup - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": { - "check_model_export": True, - }, - "device": { - "id": device, - }, - } - ) - - # Setup for the first ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device) - first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - _ = first_trainer.train_step(data, targets) - assert first_trainer._onnx_model is not None diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis.py index 506aafbe9f618..a3e666dd404f2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis.py @@ -27,7 +27,7 @@ def run_training_apis_python_api_tests(cwd, log): log.debug("Running: ort training api tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_python_bindings.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_py_bindings.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() @@ -37,7 +37,7 @@ def run_onnxblock_tests(cwd, log): log.debug("Running: onnxblock tests") - command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnxblock.py"] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ort_apis_onnxblock.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() diff --git a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_onnxblock.py rename 
to orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py similarity index 99% rename from orttraining/orttraining/test/python/orttraining_test_python_bindings.py rename to orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py index d5c37b3e36ee7..34d8c24ccfab4 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py @@ -11,7 +11,7 @@ import onnx import pytest import torch -from orttraining_test_onnxblock import _get_models +from orttraining_test_ort_apis_onnxblock import _get_models import onnxruntime.training.onnxblock as onnxblock from onnxruntime import OrtValue, SessionOptions diff --git a/orttraining/orttraining/test/python/orttraining_test_hooks.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_hooks.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_hooks.py rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_hooks.py diff --git a/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py similarity index 100% rename from orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py rename to orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py deleted file mode 100644 index 45b87b32f7d64..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py +++ /dev/null @@ -1,1283 +0,0 @@ -import copy # noqa: F401 -import inspect # noqa: 
F401 -import math # noqa: F401 -import os -from functools import partial - -import _test_commons -import _test_helpers -import onnx -import pytest -import torch -from numpy.testing import assert_allclose - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription -from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler -from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import amp, optim, orttrainer - -############################################################################### -# Helper functions ############################################################ -############################################################################### - - -def generate_random_input_from_model_desc(desc, seed=1, device="cuda:0"): - """Generates a sample input for the BERT model using the model desc""" - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - dtype = torch.int64 - vocab_size = 30528 - num_classes = [vocab_size, 2, 2, vocab_size, 2] - dims = {"batch_size": 16, "seq_len": 1} - sample_input = [] - for index, input in enumerate(desc["inputs"]): - size = [] - for s in input[1]: - if isinstance(s, (int)): - size.append(s) - else: - size.append(dims[s] if s in dims else 1) - sample_input.append(torch.randint(0, num_classes[index], tuple(size), dtype=dtype).to(device)) - return sample_input - - -# EXPERIMENTAL HELPER FUNCTIONS - - -def bert_model_description(dynamic_shape=True): - """Creates the model description dictionary with static dimensions""" - - if dynamic_shape: - model_desc = { - "inputs": [ - ("input_ids", ["batch_size", "seq_len"]), - ( - "segment_ids", - ["batch_size", "seq_len"], - ), - ( - "input_mask", - ["batch_size", "seq_len"], - ), - ( - "masked_lm_labels", - ["batch_size", "seq_len"], - ), - ( - "next_sentence_labels", - [ - "batch_size", - ], - ), - ], - 
"outputs": [("loss", [], True)], - } - else: - batch_size = 16 - seq_len = 1 - model_desc = { - "inputs": [ - ("input_ids", [batch_size, seq_len]), - ( - "segment_ids", - [batch_size, seq_len], - ), - ( - "input_mask", - [batch_size, seq_len], - ), - ( - "masked_lm_labels", - [batch_size, seq_len], - ), - ( - "next_sentence_labels", - [ - batch_size, - ], - ), - ], - "outputs": [("loss", [], True)], - } - return model_desc - - -def optimizer_parameters(model): - """A method to assign different hyper parameters for different model parameter groups""" - - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay_param_group = [] - for initializer in model.graph.initializer: - if any(key in initializer.name for key in no_decay_keys): - no_decay_param_group.append(initializer.name) - params = [ - { - "params": no_decay_param_group, - "alpha": 0.9, - "beta": 0.999, - "lambda_coef": 0.0, - "epsilon": 1e-6, - "do_bias_correction": False, - } - ] - - return params - - -def load_bert_onnx_model(): - bert_onnx_model_path = os.path.join("testdata", "bert_toy_postprocessed.onnx") - model = onnx.load(bert_onnx_model_path) - return model - - -class CustomLossScaler(amp.LossScaler): - def __init__(self, loss_scale=float(1 << 16)): - super().__init__(loss_scale) - self._initial_loss_scale = loss_scale - self.loss_scale = loss_scale - - def reset(self): - self.loss_scale = self._initial_loss_scale - - def update(self, train_step_info): - self.loss_scale *= 0.9 - return self.loss_scale - - -# LEGACY HELPER FUNCTIONS - - -class LegacyCustomLossScaler: - def __init__(self, loss_scale=float(1 << 16)): - self._initial_loss_scale = loss_scale - self.loss_scale_ = loss_scale - - def reset(self): - self.loss_scale_ = self._initial_loss_scale - - def update_loss_scale(self, is_all_finite): - self.loss_scale_ *= 0.9 - - -def legacy_model_params(lr, device=torch.device("cuda", 0)): # noqa: B008 - legacy_model_desc = legacy_bert_model_description() - learning_rate_description = 
legacy_ort_trainer_learning_rate_description() - learning_rate = torch.tensor([lr]).to(device) - return (legacy_model_desc, learning_rate_description, learning_rate) - - -def legacy_ort_trainer_learning_rate_description(): - return Legacy_IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ) - - -def legacy_bert_model_description(): - input_ids_desc = Legacy_IODescription("input_ids", ["batch", "max_seq_len_in_batch"]) - segment_ids_desc = Legacy_IODescription("segment_ids", ["batch", "max_seq_len_in_batch"]) - input_mask_desc = Legacy_IODescription("input_mask", ["batch", "max_seq_len_in_batch"]) - masked_lm_labels_desc = Legacy_IODescription("masked_lm_labels", ["batch", "max_seq_len_in_batch"]) - next_sentence_labels_desc = Legacy_IODescription( - "next_sentence_labels", - [ - "batch", - ], - ) - loss_desc = Legacy_IODescription("loss", []) - - return Legacy_ModelDescription( - [input_ids_desc, segment_ids_desc, input_mask_desc, masked_lm_labels_desc, next_sentence_labels_desc], - [loss_desc], - ) - - -def legacy_optim_params_a(name): - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -def legacy_optim_params_b(name): - params = ["bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.weight"] - if name in params: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6, "do_bias_correction": False} - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -def legacy_optim_params_c(name): - params_group = optimizer_parameters(load_bert_onnx_model()) - if name in params_group[0]["params"]: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6, "do_bias_correction": False} - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} - - -############################################################################### -# Testing starts here 
######################################################### -############################################################################### - - -@pytest.mark.parametrize("dynamic_shape", [(True), (False)]) -def testToyBERTModelBasicTraining(dynamic_shape): - model_desc = bert_model_description(dynamic_shape) - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({}) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for _i in range(10): - sample_input = generate_random_input_from_model_desc(model_desc) - output = trainer.train_step(*sample_input) - assert output.shape == torch.Size([]) - - -@pytest.mark.parametrize( - "expected_losses", - [([11.041123, 10.986166, 11.101636, 11.013366, 11.03775, 11.041175, 10.957118, 11.069563, 11.040824, 11.16437])], -) -def testToyBERTDeterministicCheck(expected_losses): - # Common setup - train_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optimizer_parameters(model) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - experimental_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "initial_lr, lr_scheduler, expected_learning_rates, expected_losses", - [ - ( - 1.0, - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [ - 10.988012313842773, - 10.99213981628418, - 
120.79301452636719, - 36.11647033691406, - 95.83200073242188, - 221.2766571044922, - 208.40316772460938, - 279.5332946777344, - 402.46380615234375, - 325.79254150390625, - ], - ), - ( - 0.5, - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], - [ - 10.988012313842773, - 10.99213981628418, - 52.69743347167969, - 19.741533279418945, - 83.88340759277344, - 126.39848327636719, - 91.53898620605469, - 63.62016296386719, - 102.21206665039062, - 180.1424560546875, - ], - ), - ( - 1.0, - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.9931806517013612, - 0.9397368756032445, - 0.8386407858128706, - 0.7008477123264848, - 0.5412896727361662, - 0.37725725642960045, - 0.22652592093878665, - 0.10542974530180327, - 0.02709137914968268, - ], - [ - 10.988012313842773, - 10.99213981628418, - 120.6441650390625, - 32.152557373046875, - 89.63705444335938, - 138.8782196044922, - 117.57748413085938, - 148.01927185058594, - 229.60403442382812, - 110.2930908203125, - ], - ), - ( - 1.0, - optim.lr_scheduler.LinearWarmupLRScheduler, - [ - 0.0, - 0.9473684210526315, - 0.8421052631578947, - 0.7368421052631579, - 0.631578947368421, - 0.5263157894736842, - 0.42105263157894735, - 0.3157894736842105, - 0.21052631578947367, - 0.10526315789473684, - ], - [ - 10.988012313842773, - 10.99213981628418, - 112.89633178710938, - 31.114538192749023, - 80.94029235839844, - 131.34490966796875, - 111.4329605102539, - 133.74252319335938, - 219.37344360351562, - 109.67041015625, - ], - ), - ( - 1.0, - optim.lr_scheduler.PolyWarmupLRScheduler, - [ - 0.0, - 0.9473684263157895, - 0.8421052789473684, - 0.7368421315789474, - 0.6315789842105263, - 0.5263158368421054, - 0.42105268947368424, - 0.31578954210526317, - 0.21052639473684212, - 0.10526324736842106, - ], - [ - 10.988012313842773, - 10.99213981628418, - 112.89633178710938, - 31.114538192749023, - 80.9402847290039, - 131.3447265625, - 111.43253326416016, - 133.7415008544922, - 219.37147521972656, - 
109.66986083984375, - ], - ), - ], -) -def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates, expected_losses): - return # TODO: re-enable after nondeterminism on backend is fixed - # Common setup - device = "cuda" - total_steps = 10 - seed = 1 - warmup = 0.05 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Setup LR Schedulers - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig(lr=initial_lr) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "lr_scheduler": lr_scheduler, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - learning_rates = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0]) - - # Check output - _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=rtol) - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "loss_scaler, expected_losses", - [ - ( - None, - [ - 11.041126, - 10.986309, - 11.101673, - 11.013394, - 
11.037781, - 11.041253, - 10.957072, - 11.069506, - 11.040807, - 11.164349, - ], - ), - ( - amp.DynamicLossScaler(), - [ - 11.041126, - 10.986309, - 11.101673, - 11.013394, - 11.037781, - 11.041253, - 10.957072, - 11.069506, - 11.040807, - 11.164349, - ], - ), - ( - CustomLossScaler(), - [ - 11.041126, - 10.986309, - 11.101645, - 11.013412, - 11.037757, - 11.041273, - 10.957077, - 11.069525, - 11.040765, - 11.164298, - ], - ), - ], -) -def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses): - # Common setup - total_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -@pytest.mark.parametrize( - "gradient_accumulation_steps, expected_losses", - [ - ( - 1, - [ - 11.041123, - 10.986166, - 11.101636, - 11.013366, - 11.03775, - 11.041175, - 10.957118, - 11.069563, - 11.040824, - 11.16437, - ], - ), - ( - 4, - [ - 11.041123, - 10.982856, - 11.105512, - 11.006721, - 11.03358, - 11.05058, - 10.955864, - 11.059035, - 11.037753, - 11.162649, - ], - ), - ( - 7, - [ - 11.041123, - 10.982856, - 11.105512, - 11.006721, - 11.036314, - 11.055109, - 10.960751, - 11.05809, - 11.038856, - 11.159635, - ], - ), - ], -) -def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses): - # Common setup - 
total_steps = 10 - device = "cuda" - seed = 1 - rtol = 1e-3 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # Modeling - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train - losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - losses.append(trainer.train_step(*sample_input).cpu().item()) - - # Check output - _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) - - -def testToyBertCheckpointBasic(): - # Common setup - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({"debug": {"deterministic_compute": True}}) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - sd = trainer.state_dict() - - ## All initializers must be present in the state_dict - ## when the specified model for ORTTRainer is an ONNX model - for param in trainer._onnx_model.graph.initializer: - assert param.name in sd["model"]["full_precision"] - - ## Modify one of the state values and load into ORTTrainer - sd["model"]["full_precision"]["bert.encoder.layer.0.attention.output.LayerNorm.weight"] += 10 - trainer.load_state_dict(sd) - - ## Save a checkpoint - ckpt_dir = "testdata" - trainer.save_checkpoint(os.path.join(ckpt_dir, "bert_toy_save_test.ortcp")) - del trainer - del model - - # Create a new ORTTrainer and load the checkpoint from previous ORTTrainer - model2 = load_bert_onnx_model() - model_desc2 = 
bert_model_description() - trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts) - trainer2.load_checkpoint(os.path.join(ckpt_dir, "bert_toy_save_test.ortcp")) - loaded_sd = trainer2.state_dict() - - # Assert whether original state and the one loaded from checkpoint matches - _test_commons.assert_all_states_close_ort(sd, loaded_sd) - - -def testToyBertCheckpointFrozenWeights(): - # Common setup - seed = 1 - total_steps = 10 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "utils": {"frozen_weights": ["bert.encoder.layer.0.attention.self.value.weight"]}, - } - ) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - optim_config = optim.LambConfig() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - # Train for a few steps - for _i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, seed) - _ = trainer.train_step(*sample_input) - sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1) - # Evaluate once to get a base loss - loss = trainer.eval_step(*sample_input) - # Save checkpoint - state_dict = trainer.state_dict() - - # Load previous state into another instance of ORTTrainer - model2 = load_bert_onnx_model() - model_desc2 = bert_model_description() - optim_config2 = optim.LambConfig() - trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts) - trainer2.load_state_dict(state_dict) - # Evaluate once to get a base loss - ckpt_loss = trainer2.eval_step(*sample_input) - - # Must match as both trainers have the same dict state - assert_allclose(loss.cpu(), ckpt_loss.cpu()) - loaded_state_dict = trainer2.state_dict() - _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict) - - -@pytest.mark.parametrize( - "optimizer, 
mixedprecision_enabled", - [ - (optim.LambConfig(), False), - (optim.AdamConfig(), False), - (optim.LambConfig(), True), - (optim.AdamConfig(), True), - ], -) -def testToyBertLoadOptimState(optimizer, mixedprecision_enabled): - # Common setup - device = "cuda" - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optimizer - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": {"id": device}, - "mixed_precision": { - "enabled": mixedprecision_enabled, - }, - "distributed": {"allreduce_post_accumulation": True}, - } - ) - - # Create ORTTrainer and save initial state in a dict - model = load_bert_onnx_model() - model_desc = bert_model_description() - dummy_init_state = _test_commons.generate_dummy_optim_state(model, optimizer) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - trainer.load_state_dict(dummy_init_state) - - # Expected values - input_ids = torch.tensor( - [ - [26598], - [21379], - [19922], - [5219], - [5644], - [20559], - [23777], - [25672], - [22969], - [16824], - [16822], - [635], - [27399], - [20647], - [18519], - [15546], - ], - device=device, - ) - segment_ids = torch.tensor( - [[0], [1], [0], [1], [0], [0], [1], [0], [0], [1], [1], [0], [0], [1], [1], [1]], device=device - ) - input_mask = torch.tensor( - [[0], [0], [0], [0], [1], [1], [1], [0], [1], [1], [0], [0], [0], [1], [0], [0]], device=device - ) - masked_lm_labels = torch.tensor( - [ - [25496], - [16184], - [11005], - [16228], - [14884], - [21660], - [8678], - [23083], - [4027], - [8397], - [11921], - [1333], - [26482], - [1666], - [17925], - [27978], - ], - device=device, - ) - next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device) - - # Actual values - _ = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels) - - actual_state_dict = trainer.state_dict() - del actual_state_dict["model"] - 
_test_commons.assert_all_states_close_ort(actual_state_dict, dummy_init_state) - - -@pytest.mark.parametrize( - "model_params", - [ - (["bert.embeddings.LayerNorm.bias"]), - ( - [ - "bert.embeddings.LayerNorm.bias", - "bert.embeddings.LayerNorm.weight", - "bert.encoder.layer.0.attention.output.LayerNorm.bias", - ] - ), - ], -) -def testORTTrainerFrozenWeights(model_params): - device = "cuda" - total_steps = 10 - seed = 1 - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - # Setup ORTTrainer WITHOUT frozen weights - opts_dict = { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - opts = orttrainer.ORTTrainerOptions(opts_dict) - - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - trainer.train_step(*sample_input) - - # All model_params must be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert all([param in session_state for param in model_params]) - - # Setup ORTTrainer WITH frozen weights - opts_dict.update({"utils": {"frozen_weights": model_params}}) - opts = orttrainer.ORTTrainerOptions(opts_dict) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - trainer.train_step(*sample_input) - - # All model_params CANNOT be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert not any([param in session_state for param in model_params]) - - -def testToyBERTSaveAsONNX(): - device = "cuda" - onnx_file_name = "_____temp_toy_bert_onnx_model.onnx" - if os.path.exists(onnx_file_name): - 
os.remove(onnx_file_name) - assert not os.path.exists(onnx_file_name) - - # Load trainer - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - trainer.save_as_onnx(onnx_file_name) - assert os.path.exists(onnx_file_name) - - with open(onnx_file_name, "rb") as f: - bin_str = f.read() - reload_onnx_model = onnx.load_model_from_string(bin_str) - os.remove(onnx_file_name) - - # Create a new trainer from persisted ONNX model and compare with original ONNX model - trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config, options=opts) - assert trainer_from_onnx._onnx_model is not None - assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model) - for initializer, loaded_initializer in zip( - trainer._onnx_model.graph.initializer, trainer_from_onnx._onnx_model.graph.initializer - ): - assert initializer.name == loaded_initializer.name - assert onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph( - trainer._onnx_model.graph - ) - _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx) - - -############################################################################### -# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ -############################################################################### -@pytest.mark.parametrize( - "optimizer_config", - [ - (optim.AdamConfig), - # (optim.LambConfig), # TODO: re-enable after nondeterminism on backend is fixed - (optim.SGDConfig), - ], -) -def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config): - # Common setup - train_steps = 512 - - device = "cuda" - seed = 1 - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - # 
EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - optim_config = optimizer_config(lr=0.01) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - - if optimizer_config == optim.AdamConfig: - legacy_optimizer = "AdamOptimizer" - elif optimizer_config == optim.LambConfig: - legacy_optimizer = "LambOptimizer" - elif optimizer_config == optim.SGDConfig: - legacy_optimizer = "SGDOptimizer" - else: - raise RuntimeError("Invalid optimizer_config") - - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - legacy_optimizer, - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - ) - legacy_losses = [] - for i in range(train_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True) - - -@pytest.mark.parametrize( - "initial_lr, lr_scheduler, legacy_lr_scheduler", - [ - (1.0, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (0.5, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (1.0, optim.lr_scheduler.CosineWarmupLRScheduler, 
_test_commons.legacy_cosine_lr_scheduler), - (1.0, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler), - (1.0, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler), - ], -) -def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, legacy_lr_scheduler): - ############################################################################ - # These tests require hard-coded values for 'total_steps' and 'initial_lr' # - ############################################################################ - - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - warmup = 0.05 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - - # Setup both Experimental and Legacy LR Schedulers before the experimental loop - if ( - legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler - or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler - ): - legacy_lr_scheduler = partial( - legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup - ) - elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler: - legacy_lr_scheduler = partial( - legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup, cycles=cycles - ) - elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler: - legacy_lr_scheduler = partial( - legacy_lr_scheduler, - initial_lr=initial_lr, - total_steps=total_steps, - warmup=warmup, - power=power, - lr_end=lr_end, - ) - else: - raise RuntimeError("Invalid legacy_lr_scheduler") - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = 
lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - # EXPERIMENTAL API - model_desc = bert_model_description() - model = load_bert_onnx_model() - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - optim_config = optim.AdamConfig(lr=initial_lr) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "lr_scheduler": lr_scheduler, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(i)) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(initial_lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - get_lr_this_step=legacy_lr_scheduler, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize( - "loss_scaler, legacy_loss_scaler", - [ - (None, Legacy_LossScaler("ort_test_input_loss_scaler", True)), - (amp.DynamicLossScaler(), Legacy_LossScaler("ort_test_input_loss_scaler", True)), - (CustomLossScaler(), LegacyCustomLossScaler()), - ], -) -def 
testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig(lr=0.001) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - use_mixed_precision=True, - loss_scaler=legacy_loss_scaler, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize("gradient_accumulation_steps", [(1), (4), (7)]) -def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation_steps): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - 
model_desc = bert_model_description() - model = load_bert_onnx_model() - optim_config = optim.AdamConfig() - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - loss = trainer.train_step(*sample_input) - experimental_losses.append(loss.cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr) - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - gradient_accumulation_steps=gradient_accumulation_steps, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - leg_loss = legacy_trainer.train_step(*sample_input, learning_rate) - legacy_losses.append(leg_loss.cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) - - -@pytest.mark.parametrize( - "params, legacy_optim_map", - [ - # Change the hyper parameters for all parameters - ([], legacy_optim_params_a), - # Change the hyperparameters for a subset of hardcoded parameters - ( - [ - { - "params": ["bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.weight"], - "alpha": 0.9, - "beta": 0.999, - "lambda_coef": 0.0, - "epsilon": 1e-6, - "do_bias_correction": False, - } - ], - legacy_optim_params_b, - ), - # Change the hyperparameters for a generated set of paramers - (optimizer_parameters(load_bert_onnx_model()), 
legacy_optim_params_c), - ], -) -def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map): - # Common setup - total_steps = 128 - device = "cuda" - seed = 1 - - # EXPERIMENTAL API - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - model_desc = bert_model_description() - model = load_bert_onnx_model() - - optim_config = optim.AdamConfig( - params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False - ) - opts = orttrainer.ORTTrainerOptions( - { - "debug": {"deterministic_compute": True}, - "device": { - "id": device, - }, - } - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) - - experimental_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - experimental_losses.append(trainer.train_step(*sample_input).cpu().item()) - - # LEGACY IMPLEMENTATION - torch.manual_seed(seed) - onnxruntime.set_seed(seed) - device = torch.device(device) - model = load_bert_onnx_model() - legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr) - - legacy_trainer = Legacy_ORTTrainer( - model, - None, - legacy_model_desc, - "AdamOptimizer", - legacy_optim_map, - learning_rate_description, - device, - _use_deterministic_compute=True, - ) - legacy_losses = [] - for i in range(total_steps): - sample_input = generate_random_input_from_model_desc(model_desc, i) - legacy_sample_input = [*sample_input, learning_rate] - legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item()) - - # Check results - _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py deleted file mode 100644 index d366f2cb26557..0000000000000 --- 
a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py +++ /dev/null @@ -1,722 +0,0 @@ -from unittest.mock import Mock, patch - -import numpy as np -import onnx -import pytest -import torch -from _test_commons import _load_pytorch_transformer_model - -from onnxruntime.training import _checkpoint_storage, amp, checkpoint, optim, orttrainer # noqa: F401 - -# Helper functions - - -def _create_trainer(zero_enabled=False): - """Cerates a simple ORTTrainer for ORTTrainer functional tests""" - - device = "cuda" - optim_config = optim.LambConfig(lr=0.1) - opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} - if zero_enabled: - opts["distributed"] = { - "world_rank": 0, - "world_size": 1, - "horizontal_parallel_size": 1, - "data_parallel_size": 1, - "allreduce_post_accumulation": True, - "deepspeed_zero_optimization": {"stage": 1}, - } - model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer( - model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(opts) - ) - - return trainer - - -class _training_session_mock: # noqa: N801 - """Mock object for the ORTTrainer _training_session member""" - - def __init__(self, model_states, optimizer_states, partition_info): - self.model_states = model_states - self.optimizer_states = optimizer_states - self.partition_info = partition_info - - def get_model_state(self, include_mixed_precision_weights=False): - return self.model_states - - def get_optimizer_state(self): - return self.optimizer_states - - def get_partition_info_map(self): - return self.partition_info - - -def _get_load_state_dict_strict_error_arguments(): - """Return a list of tuples that can be used as parameters for test_load_state_dict_errors_when_model_key_missing - - Construct a list of tuples (training_session_state_dict, input_state_dict, error_arguments) - The load_state_dict function will compare the 
two state dicts (training_session_state_dict, input_state_dict) and - throw a runtime error with the missing/unexpected keys. The error arguments capture these missing/unexpected keys. - """ - - training_session_state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5)}, - }, - } - - # input state dictionaries - precision_key_missing = {"model": {}, "optimizer": {}} - precision_key_unexpected = {"model": {"full_precision": {}, "mixed_precision": {}}, "optimizer": {}} - model_state_key_missing = {"model": {"full_precision": {}}, "optimizer": {}} - model_state_key_unexpected = {"model": {"full_precision": {"a": 2, "b": 3, "c": 4}}, "optimizer": {}} - optimizer_model_state_key_missing = {"model": {"full_precision": {"a": 2, "b": 3}}, "optimizer": {}} - optimizer_model_state_key_unexpected = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": {"a": {}, "shared_optimizer_state": {}, "b": {}}, - } - optimizer_state_key_missing = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": {"a": {}, "shared_optimizer_state": {"step": np.arange(5)}}, - } - optimizer_state_key_unexpected = { - "model": {"full_precision": {"a": 2, "b": 3}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5), "another_step": np.arange(1)}, - }, - } - - input_arguments = [ - (training_session_state_dict, precision_key_missing, ["full_precision"]), - (training_session_state_dict, precision_key_unexpected, ["mixed_precision"]), - (training_session_state_dict, model_state_key_missing, ["a", "b"]), - (training_session_state_dict, model_state_key_unexpected, ["c"]), - (training_session_state_dict, optimizer_model_state_key_missing, ["a", "shared_optimizer_state"]), - (training_session_state_dict, optimizer_model_state_key_unexpected, ["b"]), - 
(training_session_state_dict, optimizer_state_key_missing, ["Moment_1", "Moment_2"]), - (training_session_state_dict, optimizer_state_key_unexpected, ["another_step"]), - ] - - return input_arguments - - -# Tests - - -def test_empty_state_dict_when_training_session_uninitialized(): - trainer = _create_trainer() - with pytest.warns(UserWarning) as user_warning: - state_dict = trainer.state_dict() - - assert len(state_dict.keys()) == 0 - assert ( - user_warning[0].message.args[0] == "ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict()." - ) - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_model_states(onnx_model_mock): - trainer = _create_trainer() - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["model"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_model_states(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_model_states_pytorch_format(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = 
trainer.state_dict(pytorch_format=True) - assert torch.all(torch.eq(state_dict["a"], torch.tensor(np.arange(5)))) - assert torch.all(torch.eq(state_dict["b"], torch.tensor(np.arange(7)))) - - -@patch("onnx.ModelProto") -def test_onnx_graph_provides_frozen_model_states(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - training_session_mock = _training_session_mock(model_states, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - trainer.options.utils.frozen_weights = ["a_frozen_weight", "a_float16_weight"] - trainer._onnx_model.graph.initializer = [ - onnx.numpy_helper.from_array(np.array([1, 2, 3], dtype=np.float32), "a_frozen_weight"), - onnx.numpy_helper.from_array(np.array([4, 5, 6], dtype=np.float32), "a_non_fronzen_weight"), - onnx.numpy_helper.from_array(np.array([7, 8, 9], dtype=np.float16), "a_float16_weight"), - ] - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - assert (state_dict["model"]["full_precision"]["a_frozen_weight"] == np.array([1, 2, 3], dtype=np.float32)).all() - assert "a_non_fronzen_weight" not in state_dict["model"]["full_precision"] - assert (state_dict["model"]["full_precision"]["a_float16_weight"] == np.array([7, 8, 9], dtype=np.float32)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_optimizer_states(onnx_model_mock): - trainer = _create_trainer() - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["optimizer"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_optimizer_states(onnx_model_mock): - trainer = _create_trainer() - optimizer_states = { - 
"model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - training_session_mock = _training_session_mock({}, optimizer_states, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all() - assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all() - - -@patch("onnx.ModelProto") -def test_training_session_provides_optimizer_states_pytorch_format(onnx_model_mock): - trainer = _create_trainer() - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - training_session_mock = _training_session_mock(model_states, optimizer_states, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict(pytorch_format=True) - assert "optimizer" not in state_dict - - -@patch("onnx.ModelProto") -def test_training_session_provides_empty_partition_info_map(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - training_session_mock = _training_session_mock({}, {}, {}) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert len(state_dict["partition_info"].keys()) == 0 - - -@patch("onnx.ModelProto") -def test_training_session_provides_partition_info_map(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - partition_info = {"a": {"original_dim": [1, 2, 3]}} - training_session_mock = _training_session_mock({}, {}, partition_info) - trainer._training_session = training_session_mock - trainer._onnx_model 
= onnx_model_mock() - - state_dict = trainer.state_dict() - assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3] - - -@patch("onnx.ModelProto") -def test_training_session_provides_all_states(onnx_model_mock): - trainer = _create_trainer(zero_enabled=True) - model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} - optimizer_states = { - "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(1)}, - } - partition_info = {"a": {"original_dim": [1, 2, 3]}} - training_session_mock = _training_session_mock(model_states, optimizer_states, partition_info) - trainer._training_session = training_session_mock - trainer._onnx_model = onnx_model_mock() - - state_dict = trainer.state_dict() - assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() - assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all() - assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all() - assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all() - assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3] - - -def test_load_state_dict_holds_when_training_session_not_initialized(): - trainer = _create_trainer() - state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - "shared_optimizer_state": {"step": np.arange(5)}, - }, - } - assert not trainer._load_state_dict - state_dict = trainer.load_state_dict(state_dict) - assert trainer._load_state_dict - - -@pytest.mark.parametrize( - "state_dict, input_state_dict, error_key", - [ - ( - {"model": {}, "optimizer": {}}, - {"model": {}, "optimizer": {}, "trainer_options": {"optimizer_name": "LambOptimizer"}}, - "train_step_info", - ), - ( - {"optimizer": {}, 
"train_step_info": {"optimization_step": 0, "step": 0}}, - { - "optimizer": {}, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - "train_step_info": {"optimization_step": 0, "step": 0}, - }, - "model", - ), - ( - {"model": {}, "train_step_info": {"optimization_step": 0, "step": 0}}, - { - "model": {}, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - "train_step_info": {"optimization_step": 0, "step": 0}, - }, - "optimizer", - ), - ], -) -def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, input_state_dict, error_key): - trainer = _create_trainer() - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=state_dict) - trainer._update_onnx_model_initializers = Mock() - trainer._init_session = Mock() - with patch("onnx.ModelProto") as onnx_model_mock: - trainer._onnx_model = onnx_model_mock() - trainer._onnx_model.graph.initializer = [] - with pytest.warns(UserWarning) as user_warning: - trainer.load_state_dict(input_state_dict) - - assert user_warning[0].message.args[0] == f"Missing key: {error_key} in state_dict" - - -@pytest.mark.parametrize("state_dict, input_state_dict, error_keys", _get_load_state_dict_strict_error_arguments()) -def test_load_state_dict_errors_when_state_dict_mismatch(state_dict, input_state_dict, error_keys): - trainer = _create_trainer() - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=state_dict) - with pytest.raises(RuntimeError) as runtime_error: - trainer.load_state_dict(input_state_dict) - - assert any(key in str(runtime_error.value) for key in error_keys) - - -@patch("onnx.ModelProto") -def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_mock): - trainer = _create_trainer() - training_session_state_dict = { - "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, - "optimizer": { - "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, - 
"shared_optimizer_state": {"step": np.arange(1)}, - }, - } - - input_state_dict = { - "model": {"full_precision": {"a": np.array([1, 2]), "b": np.array([3, 4])}}, - "optimizer": { - "a": {"Moment_1": np.array([5, 6]), "Moment_2": np.array([7, 8])}, - "shared_optimizer_state": {"step": np.array([9])}, - }, - "trainer_options": {"optimizer_name": "LambOptimizer"}, - } - trainer._training_session = _training_session_mock({}, {}, {}) - trainer.state_dict = Mock(return_value=training_session_state_dict) - trainer._onnx_model = onnx_model_mock() - trainer._onnx_model.graph.initializer = [ - onnx.numpy_helper.from_array(np.arange(20, dtype=np.float32), "a"), - onnx.numpy_helper.from_array(np.arange(25, dtype=np.float32), "b"), - ] - trainer._update_onnx_model_initializers = Mock() - trainer._init_session = Mock() - - trainer.load_state_dict(input_state_dict) - - loaded_initializers, _ = trainer._update_onnx_model_initializers.call_args - state_dict_to_load, _ = trainer._init_session.call_args - - assert "a" in loaded_initializers[0] - assert (loaded_initializers[0]["a"] == np.array([1, 2])).all() - assert "b" in loaded_initializers[0] - assert (loaded_initializers[0]["b"] == np.array([3, 4])).all() - - assert (state_dict_to_load[0]["a"]["Moment_1"] == np.array([5, 6])).all() - assert (state_dict_to_load[0]["a"]["Moment_2"] == np.array([7, 8])).all() - assert (state_dict_to_load[0]["shared_optimizer_state"]["step"] == np.array([9])).all() - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_calls_checkpoint_storage_save(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc") - - save_args, _ = save_mock.call_args - assert "model" in save_args[0] - assert not bool(save_args[0]["model"]) - assert "optimizer" in save_args[0] - assert not bool(save_args[0]["optimizer"]) - assert save_args[1] == "abc" - - 
-@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_exclude_optimizer_states(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc", include_optimizer_states=False) - - save_args, _ = save_mock.call_args - assert "model" in save_args[0] - assert not bool(save_args[0]["model"]) - assert "optimizer" not in save_args[0] - assert save_args[1] == "abc" - - -@patch("onnxruntime.training._checkpoint_storage.save") -def test_save_checkpoint_user_dict(save_mock): - trainer = _create_trainer() - state_dict = {"model": {}, "optimizer": {}} - trainer.state_dict = Mock(return_value=state_dict) - - trainer.save_checkpoint("abc", user_dict={"abc": np.arange(4)}) - - save_args, _ = save_mock.call_args - assert "user_dict" in save_args[0] - assert save_args[0]["user_dict"] == _checkpoint_storage.to_serialized_hex({"abc": np.arange(4)}) - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -def test_load_checkpoint(aggregate_checkpoints_mock, load_mock): - trainer = _create_trainer() - trainer_options = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - } - state_dict = { - "model": {}, - "optimizer": {}, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - }, - } - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options, state_dict] - trainer.load_checkpoint("abc") - - args_list = load_mock.call_args_list - load_args, load_kwargs = args_list[0] - assert load_args[0] == "abc" - assert load_kwargs["key"] == 
"trainer_options" - load_args, load_kwargs = args_list[1] - assert load_args[0] == "abc" - assert "key" not in load_kwargs - assert not aggregate_checkpoints_mock.called - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -@pytest.mark.parametrize( - "trainer_options", - [ - { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(4), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(4), - "zero_stage": np.int64(1), - }, - { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(1), - }, - { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(1), - }, - ], -) -def test_load_checkpoint_aggregation_required_zero_enabled(aggregate_checkpoints_mock, load_mock, trainer_options): - trainer = _create_trainer() - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options] - trainer.load_checkpoint("abc") - - args_list = load_mock.call_args_list - load_args, load_kwargs = args_list[0] - assert load_args[0] == "abc" - assert load_kwargs["key"] == "trainer_options" - assert aggregate_checkpoints_mock.called - call_args, _ = aggregate_checkpoints_mock.call_args - assert call_args[0] == tuple(["abc"]) - - -@patch("onnxruntime.training._checkpoint_storage.load") -@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") -def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock): - trainer = _create_trainer() - trainer_options = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - 
"zero_stage": np.int64(0), - } - state_dict = { - "model": {}, - "optimizer": {}, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - }, - "user_dict": _checkpoint_storage.to_serialized_hex({"array": torch.tensor(np.arange(5))}), - } - trainer.load_state_dict = Mock() - - load_mock.side_effect = [trainer_options, state_dict] - user_dict = trainer.load_checkpoint("abc") - - assert torch.all(torch.eq(user_dict["array"], torch.tensor(np.arange(5)))) - - -@patch("onnxruntime.training._checkpoint_storage.load") -def test_checkpoint_aggregation(load_mock): - trainer_options1 = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - trainer_options2 = { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(1), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - - state_dict1 = { - "model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "optimizer_sharded": { - "Moment_1": np.array([9, 8, 7]), - "Moment_2": np.array([99, 88, 77]), - "Step": np.array([5]), - }, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"optimizer_sharded": {"original_dim": 
np.array([2, 3])}}, - } - - state_dict2 = { - "model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "optimizer_sharded": { - "Moment_1": np.array([6, 5, 4]), - "Moment_2": np.array([66, 55, 44]), - "Step": np.array([5]), - }, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(False), - "world_rank": np.int64(1), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}}, - } - - load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] - state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False) - - assert (state_dict["model"]["full_precision"]["optimizer_sharded"] == np.array([1, 2, 3])).all() - assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all() - assert (state_dict["optimizer"]["optimizer_sharded"]["Step"] == np.array([5])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all() - - assert state_dict["trainer_options"]["mixed_precision"] is False - assert state_dict["trainer_options"]["world_rank"] == 0 - assert state_dict["trainer_options"]["world_size"] == 1 - assert 
state_dict["trainer_options"]["horizontal_parallel_size"] == 1 - assert state_dict["trainer_options"]["data_parallel_size"] == 1 - assert state_dict["trainer_options"]["zero_stage"] == 0 - assert state_dict["trainer_options"]["optimizer_name"] == b"Adam" - - -@patch("onnxruntime.training._checkpoint_storage.load") -def test_checkpoint_aggregation_mixed_precision(load_mock): - trainer_options1 = { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - trainer_options2 = { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(1), - "world_size": np.int64(2), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(2), - "zero_stage": np.int64(1), - "optimizer_name": b"Adam", - } - - state_dict1 = { - "model": {"full_precision": {"sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "sharded": {"Moment_1": np.array([9, 8, 7]), "Moment_2": np.array([99, 88, 77]), "Step": np.array([5])}, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - }, - }, - "trainer_options": { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(0), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"sharded": {"original_dim": np.array([2, 3])}}, - } - - state_dict2 = { - "model": {"full_precision": {"sharded": np.array([4, 5, 6]), "non_sharded": np.array([11, 22, 33])}}, - "optimizer": { - "sharded": {"Moment_1": np.array([6, 5, 4]), "Moment_2": np.array([66, 55, 44]), "Step": np.array([5])}, - "non_sharded": { - "Moment_1": np.array([666, 555, 444]), - "Moment_2": np.array([6666, 5555, 4444]), - "Step": np.array([55]), - 
}, - }, - "trainer_options": { - "mixed_precision": np.bool_(True), - "world_rank": np.int64(1), - "world_size": np.int64(1), - "horizontal_parallel_size": np.int64(1), - "data_parallel_size": np.int64(1), - "zero_stage": np.int64(0), - "optimizer_name": b"Adam", - }, - "partition_info": {"sharded": {"original_dim": np.array([2, 3])}}, - } - - load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] - state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False) - - assert (state_dict["model"]["full_precision"]["sharded"] == np.array([[1, 2, 3], [4, 5, 6]])).all() - assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all() - assert (state_dict["optimizer"]["sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all() - assert (state_dict["optimizer"]["sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all() - assert (state_dict["optimizer"]["sharded"]["Step"] == np.array([5])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all() - assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all() - - assert state_dict["trainer_options"]["mixed_precision"] is True - assert state_dict["trainer_options"]["world_rank"] == 0 - assert state_dict["trainer_options"]["world_size"] == 1 - assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1 - assert state_dict["trainer_options"]["data_parallel_size"] == 1 - assert state_dict["trainer_options"]["zero_stage"] == 0 - assert state_dict["trainer_options"]["optimizer_name"] == b"Adam" diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py deleted file mode 100644 index fa13625f0ddac..0000000000000 --- 
a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ /dev/null @@ -1,2460 +0,0 @@ -import inspect -import os -import tempfile -from functools import partial - -import _test_commons -import _test_helpers -import onnx -import pytest -import torch -import torch.nn.functional as F -from numpy.testing import assert_allclose -from packaging.version import Version as StrictVersion - -from onnxruntime import SessionOptions, set_seed -from onnxruntime.capi.ort_trainer import LossScaler as Legacy_LossScaler -from onnxruntime.capi.ort_trainer import ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import PropagateCastOpsStrategy, TrainStepInfo, _utils, amp -from onnxruntime.training import model_desc_validation as md_val -from onnxruntime.training import optim, orttrainer, orttrainer_options - -############################################################################### -# Testing starts here ######################################################### -############################################################################### - -pytorch_110 = StrictVersion(".".join(torch.__version__.split(".")[:2])) >= StrictVersion("1.10.0") - - -def get_model_opset(model_onnx): - for op in model_onnx.opset_import: - if op.domain == "": - return op.version - return None - - -@pytest.mark.parametrize( - "test_input", - [({}), ({"batch": {}, "device": {}, "distributed": {}, "mixed_precision": {}, "utils": {}, "_internal_use": {}})], -) -def testORTTrainerOptionsDefaultValues(test_input): - """Test different ways of using default values for incomplete input""" - - expected_values = { - "batch": {"gradient_accumulation_steps": 1}, - "device": {"id": "cuda", "mem_limit": 0}, - "distributed": { - "world_rank": 0, - "world_size": 1, - "local_rank": 0, - "data_parallel_size": 1, - "horizontal_parallel_size": 1, - "pipeline_parallel": { - "pipeline_parallel_size": 1, - "num_pipeline_micro_batches": 1, - "pipeline_cut_info_string": "", - "sliced_schema": 
{}, - "sliced_axes": {}, - "sliced_tensor_names": [], - }, - "allreduce_post_accumulation": False, - "deepspeed_zero_optimization": { - "stage": 0, - }, - "enable_adasum": False, - }, - "lr_scheduler": None, - "mixed_precision": {"enabled": False, "loss_scaler": None}, - "graph_transformer": { - "attn_dropout_recompute": False, - "gelu_recompute": False, - "transformer_layer_recompute": False, - "number_recompute_layers": 0, - "propagate_cast_ops_config": {"strategy": PropagateCastOpsStrategy.FLOOD_FILL, "level": 1, "allow": []}, - }, - "utils": { - "frozen_weights": [], - "grad_norm_clip": True, - "memory_efficient_gradient": False, - "run_symbolic_shape_infer": False, - }, - "debug": { - "deterministic_compute": False, - "check_model_export": False, - "graph_save_paths": { - "model_after_graph_transforms_path": "", - "model_with_gradient_graph_path": "", - "model_with_training_graph_path": "", - "model_with_training_graph_after_optimization_path": "", - }, - }, - "_internal_use": { - "enable_internal_postprocess": True, - "extra_postprocess": None, - "onnx_opset_version": 14, - "enable_onnx_contrib_ops": True, - }, - "provider_options": {}, - "session_options": None, - } - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values._validated_opts == expected_values - - -@pytest.mark.parametrize( - "input,error_msg", - [ - ( - {"mixed_precision": {"enabled": 1}}, - "Invalid options: {'mixed_precision': [{'enabled': ['must be of boolean type']}]}", - ) - ], -) -def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema(input, error_msg): - """Test an invalid input based on schema validation error message""" - - with pytest.raises(ValueError) as e: - orttrainer_options.ORTTrainerOptions(input) - assert str(e.value) == error_msg - - -@pytest.mark.parametrize( - "input_dict,input_dtype,output_dtype", - [ - ( - {"inputs": [("in0", [])], "outputs": [("out0", []), ("out1", [])]}, - (torch.int,), - ( - torch.float, - torch.int32, - ), - ), 
- ({"inputs": [("in0", ["batch", 2, 3])], "outputs": [("out0", [], True)]}, (torch.int8,), (torch.int16,)), - ( - { - "inputs": [ - ("in0", []), - ("in1", [1]), - ("in2", [1, 2]), - ("in3", [1000, "dyn_ax1"]), - ("in4", ["dyn_ax1", "dyn_ax2", "dyn_ax3"]), - ], - "outputs": [("out0", [], True), ("out1", [1], False), ("out2", [1, "dyn_ax1", 3])], - }, - ( - torch.float, - torch.uint8, - torch.bool, - torch.double, - torch.half, - ), - (torch.float, torch.float, torch.int64), - ), - ], -) -def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype): - r"""Test different ways of using default values for incomplete input""" - - model_description = md_val._ORTTrainerModelDesc(input_dict) - - # Validating hard-coded learning rate description - assert model_description.learning_rate.name == md_val.LEARNING_RATE_IO_DESCRIPTION_NAME - assert model_description.learning_rate.shape == [1] - assert model_description.learning_rate.dtype == torch.float32 - - # Validating model description from user - for idx, i_desc in enumerate(model_description.inputs): - assert isinstance(i_desc, model_description._InputDescription) - assert len(i_desc) == 2 - assert input_dict["inputs"][idx][0] == i_desc.name - assert input_dict["inputs"][idx][1] == i_desc.shape - for idx, o_desc in enumerate(model_description.outputs): - assert isinstance(o_desc, model_description._OutputDescription) - assert len(o_desc) == 3 - assert input_dict["outputs"][idx][0] == o_desc.name - assert input_dict["outputs"][idx][1] == o_desc.shape - is_loss = input_dict["outputs"][idx][2] if len(input_dict["outputs"][idx]) == 3 else False - assert is_loss == o_desc.is_loss - - # Set all_finite name and check its description - model_description.all_finite = md_val.ALL_FINITE_IO_DESCRIPTION_NAME - assert model_description.all_finite.name == md_val.ALL_FINITE_IO_DESCRIPTION_NAME - assert model_description.all_finite.shape == [1] - assert model_description.all_finite.dtype == torch.bool - - # Set 
loss_scale_input and check its description - model_description.loss_scale_input = md_val.LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME - assert model_description.loss_scale_input.name == md_val.LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME - assert model_description.loss_scale_input.shape == [] - assert model_description.loss_scale_input.dtype == torch.float32 - - # Append type to inputs/outputs tuples - for idx, i_desc in enumerate(model_description.inputs): # noqa: B007 - model_description.add_type_to_input_description(idx, input_dtype[idx]) - for idx, o_desc in enumerate(model_description.outputs): # noqa: B007 - model_description.add_type_to_output_description(idx, output_dtype[idx]) - - # Verify inputs/outputs tuples are replaced by the typed counterparts - for idx, i_desc in enumerate(model_description.inputs): - assert isinstance(i_desc, model_description._InputDescriptionTyped) - assert input_dtype[idx] == i_desc.dtype - for idx, o_desc in enumerate(model_description.outputs): - assert isinstance(o_desc, model_description._OutputDescriptionTyped) - assert output_dtype[idx] == o_desc.dtype - - -@pytest.mark.parametrize( - "input_dict,error_msg", - [ - ( - {"inputs": [(True, [])], "outputs": [(True, [])]}, - "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], " - "'outputs': [{0: ['the first element of the tuple (aka name) must be a string']}]}", - ), - ( - {"inputs": [("in1", None)], "outputs": [("out1", None)]}, - "Invalid model_desc: {'inputs': [{0: ['the second element of the tuple (aka shape) must be a list']}], " - "'outputs': [{0: ['the second element of the tuple (aka shape) must be a list']}]}", - ), - ( - {"inputs": [("in1", [])], "outputs": [("out1", [], None)]}, - "Invalid model_desc: {'outputs': [{0: ['the third element of the tuple (aka is_loss) must be a boolean']}]}", - ), - ( - {"inputs": [("in1", [True])], "outputs": [("out1", [True])]}, - "Invalid model_desc: {'inputs': [{0: ['each shape must be either a 
string or integer']}], " - "'outputs': [{0: ['each shape must be either a string or integer']}]}", - ), - ( - {"inputs": [("in1", [])], "outputs": [("out1", [], True), ("out2", [], True)]}, - "Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}", - ), - ( - {"inputz": [("in1", [])], "outputs": [("out1", [], True)]}, - "Invalid model_desc: {'inputs': ['required field'], 'inputz': ['unknown field']}", - ), - ( - {"inputs": [("in1", [])], "outputz": [("out1", [], True)]}, - "Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}", - ), - ], -) -def testORTTrainerModelDescInvalidSchemas(input_dict, error_msg): - r"""Test different ways of using default values for incomplete input""" - with pytest.raises(ValueError) as e: - md_val._ORTTrainerModelDesc(input_dict) - assert str(e.value) == error_msg - - -def testDynamicLossScaler(): - rtol = 1e-7 - default_scaler = amp.loss_scaler.DynamicLossScaler() - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(default_scaler.loss_scale, float(1 << 16), rtol=rtol, err_msg="loss scale mismatch") - assert default_scaler.up_scale_window == 2000 - assert_allclose(default_scaler.min_loss_scale, 1.0, rtol=rtol, err_msg="min loss scale mismatch") - assert_allclose(default_scaler.max_loss_scale, float(1 << 24), rtol=rtol, err_msg="max loss scale mismatch") - - # Performing 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True) - loss_scale = float(1 << 16) - for cycles in range(1, 10): - # 1999 updates without overflow produces 1999 stable steps - for i in range(1, 2000): - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == i - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg=f"loss scale mismatch at update {i}") - - # 2000th update without overflow doubles the loss and zero stable steps until max_loss_scale is reached - 
new_loss_scale = default_scaler.update(train_step_info) - if cycles <= 8: - loss_scale *= 2 - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # After 8 cycles, loss scale should be float(1 << 16)*(2**8) - assert_allclose(new_loss_scale, float(1 << 16) * (2**8), rtol=rtol, err_msg="loss scale mismatch") - - # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on - loss_scale = float(1 << 16) * (2**8) - for count in range(1, 2050): - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == (count % 2000) - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # Setting train_step_info.all_finite = False to test down scaling - train_step_info.all_finite = False - - # Performing 24 updates to half the loss scale each time - loss_scale = float(1 << 16) * (2**8) - for count in range(1, 25): # noqa: B007 - new_loss_scale = default_scaler.update(train_step_info) - loss_scale /= 2 - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - # After 24 updates with gradient overflow, loss scale is 1.0 - assert_allclose(new_loss_scale, 1.0, rtol=rtol, err_msg="loss scale mismatch") - - # After 25 updates, min_loss_scale is reached and loss scale is not halfed from that point on - for count in range(1, 5): # noqa: B007 - new_loss_scale = default_scaler.update(train_step_info) - assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") - - -def testDynamicLossScalerCustomValues(): - rtol = 1e-7 - scaler = amp.loss_scaler.DynamicLossScaler( - automatic_update=False, loss_scale=3, up_scale_window=7, min_loss_scale=5, max_loss_scale=10 - ) - assert scaler.automatic_update is False - assert_allclose(scaler.loss_scale, 3, 
rtol=rtol, err_msg="loss scale mismatch") - assert_allclose(scaler.min_loss_scale, 5, rtol=rtol, err_msg="min loss scale mismatch") - assert_allclose(scaler.max_loss_scale, 10, rtol=rtol, err_msg="max loss scale mismatch") - assert scaler.up_scale_window == 7 - - -def testTrainStepInfo(): - """Test valid initializations of TrainStepInfo""" - - optimizer_config = optim.LambConfig() - fetches = ["out1", "out2"] - step_info = orttrainer.TrainStepInfo( - optimizer_config=optimizer_config, all_finite=False, fetches=fetches, optimization_step=123, step=456 - ) - assert step_info.optimizer_config == optimizer_config - assert step_info.all_finite is False - assert step_info.fetches == fetches - assert step_info.optimization_step == 123 - assert step_info.step == 456 - - step_info = orttrainer.TrainStepInfo(optimizer_config) - assert step_info.optimizer_config == optimizer_config - assert step_info.all_finite is True - assert step_info.fetches == [] - assert step_info.optimization_step == 0 - assert step_info.step == 0 - - -@pytest.mark.parametrize( - "invalid_input", - [ - (-1), - ("Hello"), - ], -) -def testTrainStepInfoInvalidInput(invalid_input): - """Test invalid initialization of TrainStepInfo""" - optimizer_config = optim.LambConfig() - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, all_finite=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, fetches=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, optimization_step=invalid_input) - - with pytest.raises(AssertionError): - orttrainer.TrainStepInfo(optimizer_config, step=invalid_input) - - -@pytest.mark.parametrize( - "optim_name,lr,alpha,default_alpha", - [ - ("AdamOptimizer", 0.1, 0.2, None), - ("LambOptimizer", 0.2, 0.3, None), - ("SGDOptimizer", 0.3, 0.4, None), - ("SGDOptimizer", 0.3, 
0.4, 0.5), - ], -) -def testOptimizerConfig(optim_name, lr, alpha, default_alpha): - """Test initialization of _OptimizerConfig""" - defaults = {"lr": lr, "alpha": alpha} - params = [{"params": ["fc1.weight", "fc2.weight"]}] - if default_alpha is not None: - params[0].update({"alpha": default_alpha}) - else: - params[0].update({"alpha": alpha}) - cfg = optim.config._OptimizerConfig(name=optim_name, params=params, defaults=defaults) - - assert cfg.name == optim_name - rtol = 1e-07 - assert_allclose(defaults["lr"], cfg.lr, rtol=rtol, err_msg="lr mismatch") - - # 1:1 mapping between defaults and params's hyper parameters - for param in params: - for k in param: - if k != "params": - assert k in cfg.defaults, "hyper parameter {k} not present in one of the parameter params" - for k in cfg.defaults: - for param in cfg.params: - assert k in param, "hyper parameter {k} not present in one of the parameter params" - - -@pytest.mark.parametrize( - "optim_name,defaults,params", - [ - ("AdamOptimizer", {"lr": -1}, []), # invalid lr - ("FooOptimizer", {"lr": 0.001}, []), # invalid name - ("SGDOptimizer", [], []), # invalid type(defaults) - (optim.AdamConfig, {"lr": 0.003}, []), # invalid type(name) - ("AdamOptimizer", {"lr": None}, []), # missing 'lr' hyper parameter - ("SGDOptimizer", {"lr": 0.004}, {}), # invalid type(params) - # invalid type(params[i]) - ("AdamOptimizer", {"lr": 0.005, "alpha": 2}, [[]]), - # missing 'params' at 'params' - ("AdamOptimizer", {"lr": 0.005, "alpha": 2}, [{"alpha": 1}]), - # missing 'alpha' at 'defaults' - ("AdamOptimizer", {"lr": 0.005}, [{"params": "param1", "alpha": 1}]), - ], -) -def testOptimizerConfigInvalidInputs(optim_name, defaults, params): - """Test invalid initialization of _OptimizerConfig""" - - with pytest.raises(AssertionError): - optim.config._OptimizerConfig(name=optim_name, params=params, defaults=defaults) - - -def testOptimizerConfigSGD(): - """Test initialization of SGD""" - cfg = optim.SGDConfig() - assert cfg.name == 
"SGDOptimizer" - - rtol = 1e-07 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - - cfg = optim.SGDConfig(lr=0.002) - assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch") - - # SGD does not support params - with pytest.raises(AssertionError) as e: - params = [{"params": ["layer1.weight"], "lr": 0.1}] - optim.SGDConfig(params=params, lr=0.002) - assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert str(e.value) == "'params' must be an empty list for SGD optimizer" - - -def testOptimizerConfigAdam(): - """Test initialization of Adam""" - cfg = optim.AdamConfig() - assert cfg.name == "AdamOptimizer" - - rtol = 1e-7 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch") - assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch") - assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch") - assert_allclose(1e-8, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch") - assert_allclose(1.0, cfg.max_norm_clip, rtol=rtol, err_msg="max_norm_clip mismatch") - assert cfg.do_bias_correction is True, "lambda_coef mismatch" - assert cfg.weight_decay_mode == optim.AdamConfig.DecayMode.BEFORE_WEIGHT_UPDATE, "weight_decay_mode mismatch" - - -def testOptimizerConfigLamb(): - """Test initialization of Lamb""" - cfg = optim.LambConfig() - assert cfg.name == "LambOptimizer" - rtol = 1e-7 - assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") - assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch") - assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch") - assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch") - assert cfg.ratio_min == float("-inf"), "ratio_min mismatch" - assert cfg.ratio_max == float("inf"), "ratio_max mismatch" - assert_allclose(1e-6, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch") - assert_allclose(1.0, cfg.max_norm_clip, rtol=rtol, 
err_msg="max_norm_clip mismatch") - assert cfg.do_bias_correction is False, "do_bias_correction mismatch" - - -@pytest.mark.parametrize("optim_name", [("Adam"), ("Lamb")]) -def testOptimizerConfigParams(optim_name): - rtol = 1e-7 - params = [{"params": ["layer1.weight"], "alpha": 0.1}] - if optim_name == "Adam": - cfg = optim.AdamConfig(params=params, alpha=0.2) - elif optim_name == "Lamb": - cfg = optim.LambConfig(params=params, alpha=0.2) - else: - raise ValueError("invalid input") - assert len(cfg.params) == 1, "params should have length 1" - assert_allclose(cfg.params[0]["alpha"], 0.1, rtol=rtol, err_msg="invalid lr on params[0]") - - -@pytest.mark.parametrize("optim_name", [("Adam"), ("Lamb")]) -def testOptimizerConfigInvalidParams(optim_name): - # lr is not supported within params - with pytest.raises(AssertionError) as e: - params = [{"params": ["layer1.weight"], "lr": 0.1}] - if optim_name == "Adam": - optim.AdamConfig(params=params, lr=0.2) - elif optim_name == "Lamb": - optim.LambConfig(params=params, lr=0.2) - else: - raise ValueError("invalid input") - assert str(e.value) == "'lr' is not supported inside params" - - -def testLinearLRSchedulerCreation(): - total_steps = 10 - warmup = 0.05 - - lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(total_steps, warmup) - - # Initial state - assert lr_scheduler.total_steps == total_steps - assert lr_scheduler.warmup == warmup - - -@pytest.mark.parametrize( - "lr_scheduler,expected_values", - [ - (optim.lr_scheduler.ConstantWarmupLRScheduler, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0]), - ( - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.9763960957919413, - 0.9059835861602854, - 0.7956724530494887, - 0.6563036824392345, - 0.5015739416158049, - 0.34668951940611276, - 0.2068719061737831, - 0.09586187986225325, - 0.0245691111902418, - ], - ), - (optim.lr_scheduler.LinearWarmupLRScheduler, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2]), - ( - optim.lr_scheduler.PolyWarmupLRScheduler, 
- [ - 0.0, - 0.9509018036072144, - 0.9008016032064128, - 0.8507014028056112, - 0.8006012024048097, - 0.750501002004008, - 0.7004008016032064, - 0.6503006012024048, - 0.6002004008016032, - 0.5501002004008015, - ], - ), - ], -) -def testLRSchedulerUpdateImpl(lr_scheduler, expected_values): - # Test tolerance - rtol = 1e-03 - - # Initial state - initial_lr = 1 - total_steps = 10 - warmup = 0.5 - optimizer_config = optim.SGDConfig(lr=initial_lr) - lr_scheduler = lr_scheduler(total_steps, warmup) - - # First half is warmup - for optimization_step in range(total_steps): - # Emulate ORTTRainer.train_step() call that updates its train_step_info - train_step_info = TrainStepInfo(optimizer_config=optimizer_config, optimization_step=optimization_step) - - lr_scheduler._step(train_step_info) - lr_list = lr_scheduler.get_last_lr() - assert len(lr_list) == 1 - assert_allclose(lr_list[0], expected_values[optimization_step], rtol=rtol, err_msg="lr mismatch") - - -def testInstantiateORTTrainerOptions(): - session_options = SessionOptions() - session_options.enable_mem_pattern = False - provider_options = {"EP1": {"key": "val"}} - opts = {"session_options": session_options, "provider_options": provider_options} - opts = orttrainer.ORTTrainerOptions(opts) - assert opts.session_options.enable_mem_pattern is False - assert opts._validated_opts["provider_options"]["EP1"]["key"] == "val" - - -@pytest.mark.parametrize( - "step_fn, lr_scheduler, expected_lr_values, device", - [ - ("train_step", None, None, "cuda"), - ("eval_step", None, None, "cpu"), - ( - "train_step", - optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0], - "cpu", - ), - ( - "train_step", - optim.lr_scheduler.CosineWarmupLRScheduler, - [ - 0.0, - 0.2, - 0.4, - 0.6, - 0.8, - 1.0, - 0.9045084971874737, - 0.6545084971874737, - 0.34549150281252633, - 0.09549150281252633, - ], - "cuda", - ), - ( - "train_step", - optim.lr_scheduler.LinearWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 
0.8, 1.0, 0.8, 0.6, 0.4, 0.2], - "cpu", - ), - ( - "train_step", - optim.lr_scheduler.PolyWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.80000002, 0.60000004, 0.40000006000000005, 0.20000007999999997], - "cuda", - ), - ], -) -def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device): - total_steps = 1 - initial_lr = 1.0 - rtol = 1e-3 - - # PyTorch Transformer model as example - opts = {"device": {"id": device}} - if lr_scheduler: - total_steps = 10 - opts.update({"lr_scheduler": lr_scheduler(total_steps=total_steps, warmup=0.5)}) - opts = orttrainer.ORTTrainerOptions(opts) - optim_config = optim.LambConfig(lr=initial_lr) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - - # Run a train or evaluation step - if step_fn == "eval_step": - data, targets = batcher_fn(val_data, 0) - elif step_fn == "train_step": - data, targets = batcher_fn(train_data, 0) - else: - raise ValueError("Invalid step_fn") - - # Export model to ONNX - if step_fn == "eval_step": - step_fn = trainer.eval_step - output = trainer.eval_step(data, targets) - elif step_fn == "train_step": - step_fn = trainer.train_step - for i in range(total_steps): - output = trainer.train_step(data, targets) - if lr_scheduler: - lr_list = trainer.options.lr_scheduler.get_last_lr() - assert_allclose(lr_list[0], expected_lr_values[i], rtol=rtol, err_msg="lr mismatch") - else: - raise ValueError("Invalid step_fn") - assert trainer._onnx_model is not None - - # Check output shape after train/eval step - for out, desc in zip(output, trainer.model_desc.outputs): - if trainer.loss_fn and desc.is_loss: - continue - assert list(out.size()) == desc.shape - - # Check name, shape and dtype of the first len(forward.parameters) ORT graph inputs - sig = inspect.signature(model.forward) - for i in 
range(len(sig.parameters.keys())): - input_name = trainer.model_desc.inputs[i][0] - input_dim = trainer.model_desc.inputs[i][1] - input_type = trainer.model_desc.inputs[i][2] - - assert trainer._onnx_model.graph.input[i].name == input_name - for dim_idx, dim in enumerate(trainer._onnx_model.graph.input[i].type.tensor_type.shape.dim): - assert input_dim[dim_idx] == dim.dim_value - assert input_type == _utils.dtype_onnx_to_torch( - trainer._onnx_model.graph.input[i].type.tensor_type.elem_type - ) - - opset = get_model_opset(trainer._onnx_model) - - # Check name, shape and dtype of the ORT graph outputs - for i in range(len(trainer.model_desc.outputs)): - output_name = trainer.model_desc.outputs[i][0] - output_dim = trainer.model_desc.outputs[i][1] - output_type = trainer.model_desc.outputs[i][3] - - assert trainer._onnx_model.graph.output[i].name == output_name - for dim_idx, dim in enumerate(trainer._onnx_model.graph.output[i].type.tensor_type.shape.dim): - if opset is None or opset <= 12: - assert output_dim[dim_idx] == dim.dim_value - assert output_type == _utils.dtype_onnx_to_torch( - trainer._onnx_model.graph.output[i].type.tensor_type.elem_type - ) - - # Save current model as ONNX as a file - file_name = os.path.join("_____temp_onnx_model.onnx") - trainer.save_as_onnx(file_name) - assert os.path.exists(file_name) - with open(file_name, "rb") as f: - bin_str = f.read() - reload_onnx_model = onnx.load_model_from_string(bin_str) - os.remove(file_name) - - # Create a new trainer from persisted ONNX model and compare with original ONNX model - trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config) - step_fn(data, targets) - assert trainer_from_onnx._onnx_model is not None - assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model) - assert trainer_from_onnx._onnx_model == trainer._onnx_model - assert trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph - assert 
onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph( - trainer._onnx_model.graph - ) - - -@pytest.mark.parametrize("seed, device", [(0, "cpu"), (24, "cuda")]) -def testORTDeterministicCompute(seed, device): - # Common setup - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - {"debug": {"deterministic_compute": True}, "device": {"id": device, "mem_limit": 10 * 1024 * 1024}} - ) - - # Setup for the first ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - first_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - _ = first_trainer.train_step(data, targets) - assert first_trainer._onnx_model is not None - - # Setup for the second ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - model, _, _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device) - second_trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - _ = second_trainer.train_step(data, targets) - assert second_trainer._onnx_model is not None - - # Compare two different instances with identical setup - assert id(first_trainer._onnx_model) != id(second_trainer._onnx_model) - _test_helpers.assert_onnx_weights(first_trainer, second_trainer) - - -@pytest.mark.parametrize( - "seed,device,expected_loss,fetches", - [ - (321, "cuda", [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], False), - (321, "cuda", [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], True), - ], -) -def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches): - return # TODO: re-enable after nondeterminism on backend is fixed. 
update numbers - - rtol = 1e-3 - total_steps = len(expected_loss) - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - loss_scaler = amp.DynamicLossScaler() - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - if fetches: - trainer._train_step_info.fetches = ["loss"] - loss = trainer.train_step(data, targets) - else: - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Eval once just to test fetches in action - val_data, val_targets = batcher_fn(val_data, 0) - if fetches: - trainer._train_step_info.fetches = ["loss"] - loss = trainer.eval_step(val_data, val_targets) - trainer._train_step_info.fetches = [] - loss, _ = trainer.eval_step(val_data, val_targets) - - # Compare loss to ground truth computed from current ORTTrainer API - _test_helpers.assert_model_outputs(expected_loss, actual_loss, True, rtol=rtol) - assert trainer._onnx_model is not None - - -def _recompute_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - expected_loss = { - 12: [10.5598, 10.4591, 10.3477, 10.2726, 10.1945], - 14: [10.54088, 10.498755, 10.386827, 10.338747, 10.262459], - } - return [ - (False, False, False, 0, expected_loss), # no recompute - (True, False, False, 0, expected_loss), # attn_dropout recompute - (False, True, False, 0, expected_loss), # gelu recompute - (False, False, True, 0, expected_loss), # transformer_layer recompute - 
(False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer - ] - elif device_capability_major == 5: # M60 for CI machines - expected_loss = { - 12: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], - 14: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], - } - return [ - (False, False, False, 0, expected_loss), # no recompute - (True, False, False, 0, expected_loss), # attn_dropout recompute - (False, True, False, 0, expected_loss), # gelu recompute - (False, False, True, 0, expected_loss), # transformer_layer recompute - (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer - ] - - -@pytest.mark.parametrize("attn_dropout, gelu, transformer_layer, number_layers, expected_loss", _recompute_data()) -def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers, expected_loss): - seed = 321 - device = "cuda" - rtol = 1e-3 - total_steps = len(expected_loss[12]) - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "graph_transformer": { - "attn_dropout_recompute": attn_dropout, - "gelu_recompute": gelu, - "transformer_layer_recompute": transformer_layer, - "number_recompute_layers": number_layers, - }, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( - device - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Compare loss to ground truth computed from current ORTTrainer API - assert trainer._onnx_model is not None - opset = get_model_opset(trainer._onnx_model) - 
_test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, True, rtol=rtol) - - -@pytest.mark.parametrize( - "seed,device,gradient_accumulation_steps,total_steps,expected_loss", - [ - ( - 0, - "cuda", - 1, - 12, - [ - 10.5368022919, - 10.4146203995, - 10.3635568619, - 10.2650547028, - 10.2284049988, - 10.1304626465, - 10.0853414536, - 9.9987659454, - 9.9472427368, - 9.8832416534, - 9.8223171234, - 9.8222122192, - ], - ), - ( - 42, - "cuda", - 3, - 12, - [ - 10.6455879211, - 10.6247081757, - 10.6361322403, - 10.5187482834, - 10.5345087051, - 10.5487670898, - 10.4833698273, - 10.4600019455, - 10.4535751343, - 10.3774127960, - 10.4144191742, - 10.3757553101, - ], - ), - ( - 123, - "cuda", - 7, - 12, - [ - 10.5353469849, - 10.5261383057, - 10.5240392685, - 10.5013713837, - 10.5678377151, - 10.5452117920, - 10.5184345245, - 10.4271221161, - 10.4458627701, - 10.4864749908, - 10.4416503906, - 10.4467563629, - ], - ), - ( - 321, - "cuda", - 12, - 12, - [ - 10.5773944855, - 10.5428829193, - 10.5974750519, - 10.5416746140, - 10.6009902954, - 10.5684127808, - 10.5759754181, - 10.5636739731, - 10.5613927841, - 10.5825119019, - 10.6031589508, - 10.6199369431, - ], - ), - ], -) -def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps, expected_loss): - return # TODO: re-enable after nondeterminism on backend is fixed. 
update numbers - rtol = 1e-3 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(expected_loss, actual_loss, rtol=rtol) - - -@pytest.mark.parametrize( - "dynamic_axes", - [ - (True), - (False), - ], -) -def testORTTrainerDynamicShape(dynamic_axes): - # Common setup - device = "cuda" - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model( - device, dynamic_axes=dynamic_axes - ) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - total_steps = 10 - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - if dynamic_axes: - # Forcing batches with different sizes to exercise dynamic shapes - data = data[: -(i + 1)] - targets = targets[: -(i + 1) * data.size(1)] - _, _ = trainer.train_step(data, targets) - - assert trainer._onnx_model is not None - - -@pytest.mark.parametrize( - "enable_onnx_contrib_ops", - [ - (True), - (False), - ], -) -def testORTTrainerInternalUseContribOps(enable_onnx_contrib_ops): - # Common setup - device = "cuda" - - # Setup ORTTrainer - options = 
orttrainer.ORTTrainerOptions({"_internal_use": {"enable_onnx_contrib_ops": enable_onnx_contrib_ops}}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - data, targets = batcher_fn(train_data, 0) - if not enable_onnx_contrib_ops and not pytorch_110: - with pytest.raises(Exception): # noqa: B017 - _, _ = trainer.train_step(data, targets) - else: - _, _ = trainer.train_step(data, targets) - - -@pytest.mark.parametrize( - "model_params", - [ - ( - [ - "decoder.weight", - "transformer_encoder.layers.0.linear1.bias", - "transformer_encoder.layers.0.linear2.weight", - "transformer_encoder.layers.1.self_attn.out_proj.weight", - "transformer_encoder.layers.1.self_attn.out_proj.bias", - ] - ), - ], -) -def testORTTrainerFrozenWeights(model_params): - # Common setup - device = "cuda" - total_steps = 10 - - # Setup ORTTrainer WITHOUT frozen weights - options = orttrainer.ORTTrainerOptions({}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = trainer.train_step(data, targets) - - # All model_params must be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert all([param in session_state for param in model_params]) - - # Setup ORTTrainer WITH frozen weights - options = orttrainer.ORTTrainerOptions({"utils": {"frozen_weights": model_params}}) - model, _, _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, 
loss_fn=my_loss, options=options) - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = trainer.train_step(data, targets) - - # All model_params CANNOT be in the session state - assert trainer._onnx_model is not None - session_state = trainer._training_session.get_state() - assert not all([param in session_state for param in model_params]) - - -@pytest.mark.parametrize( - "loss_scaler, optimizer_config, gradient_accumulation_steps", - [ - (None, optim.AdamConfig(), 1), - (None, optim.LambConfig(), 1), - (None, optim.SGDConfig(), 1), - (amp.DynamicLossScaler(), optim.AdamConfig(), 1), - (amp.DynamicLossScaler(), optim.LambConfig(), 5), - # (amp.DynamicLossScaler(), optim.SGDConfig(), 1), # SGD doesnt support fp16 - ], -) -def testORTTrainerStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps): - # Common setup - seed = 1 - - class LinearModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - - def forward(self, y=None, x=None): - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - model_desc = { - "inputs": [ - ("x", [2, 2]), - ( - "label", - [ - 2, - ], - ), - ], - "outputs": [("loss", [], True), ("output", [2, 4])], - } - - # Dummy data - data1 = torch.randn(2, 2) - label1 = torch.tensor([0, 1], dtype=torch.int64) - data2 = torch.randn(2, 2) - label2 = torch.tensor([0, 1], dtype=torch.int64) - - # Setup training based on test parameters - opts = { - "debug": {"deterministic_compute": True}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - } - if loss_scaler: - opts["mixed_precision"] = {"enabled": True, "loss_scaler": loss_scaler} - opts = orttrainer.ORTTrainerOptions(opts) - - # Training session 1 - torch.manual_seed(seed) - set_seed(seed) - pt_model = LinearModel() - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - trainer = 
orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) - - # Check state_dict keys before train. Must be empty - state_dict = trainer.state_dict() - assert state_dict == {} - - # Train once and check initial state - trainer.train_step(x=data1, label=label1) - state_dict = trainer.state_dict() - assert all([weight in state_dict["model"]["full_precision"] for weight in ["linear.bias", "linear.weight"]]) - - # Initialize training session 2 from state of Training 1 - torch.manual_seed(seed) - set_seed(seed) - trainer2 = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) - trainer2.load_state_dict(state_dict) - - # Verify state was loaded properly - _test_commons.assert_all_states_close_ort(state_dict, trainer2._load_state_dict.args[0]) - - # Perform a second step in both training session 1 and 2 and verify they match - trainer.train_step(x=data2, label=label2) - state_dict = trainer.state_dict() - trainer2.train_step(x=data2, label=label2) - state_dict2 = trainer2.state_dict() - _test_commons.assert_all_states_close_ort(state_dict, state_dict2) - - -def testORTTrainerNonPickableModel(): - # Common setup - import threading - - seed = 1 - - class UnpickableModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 4) - self._lock = threading.Lock() - - def forward(self, y=None, x=None): - with self._lock: - if y is not None: - return self.linear(x) + y - else: - return self.linear(x) + torch.ones(2, 4) - - model_desc = { - "inputs": [ - ("x", [2, 2]), - ( - "label", - [ - 2, - ], - ), - ], - "outputs": [("loss", [], True), ("output", [2, 4])], - } - - # Dummy data - data = torch.randn(2, 2) - label = torch.tensor([0, 1], dtype=torch.int64) - - # Setup training based on test parameters - opts = orttrainer.ORTTrainerOptions({"debug": {"deterministic_compute": True}}) - - # Training session - torch.manual_seed(seed) - set_seed(seed) - pt_model = 
UnpickableModel() - - def loss_fn(x, label): - return F.nll_loss(F.log_softmax(x, dim=1), label) - - optim_config = optim.AdamConfig() - trainer = orttrainer.ORTTrainer(pt_model, model_desc, optim_config, loss_fn=loss_fn, options=opts) - - # Train must succeed despite warning - _, _ = trainer.train_step(data, label) - - -############################################################################### -# Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ -############################################################################### - - -@pytest.mark.parametrize("seed,device", [(1234, "cuda")]) -def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device): - # Common data - rtol = 1e-7 - total_steps = 5 - - # Setup for the experimental ORTTRainer run - torch.manual_seed(seed) - set_seed(seed) - optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - # Training loop - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _ = trainer.train_step(data, targets) - - # Setup for the legacy ORTTrainer run - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, my_loss, model_desc, "LambOptimizer", None, lr_desc, device, _use_deterministic_compute=True - ) - # Training loop - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - _, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - - # Compare legacy vs experimental APIs - _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=rtol) - - 
-@pytest.mark.parametrize( - "seed,device", - [ - (321, "cuda"), - ], -) -def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device): - # Common data - total_steps = 128 - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - loss_scaler = amp.DynamicLossScaler() - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - experimental_preds_dtype = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, exp_preds = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - experimental_preds_dtype.append(exp_preds.dtype) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - loss_scaler = Legacy_LossScaler("ort_test_input_loss_scalar", True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - use_mixed_precision=True, - loss_scaler=loss_scaler, - ) - # Training loop - legacy_loss = [] - legacy_preds_dtype = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, leg_preds = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(leg_loss.cpu()) - legacy_preds_dtype.append(leg_preds.dtype) - - # Compare legacy vs experimental APIs - assert experimental_preds_dtype == legacy_preds_dtype - 
_test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer) - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -@pytest.mark.parametrize( - "seed,device,gradient_accumulation_steps,total_steps", - [ - (0, "cuda", 1, 12), - (42, "cuda", 3, 12), - (123, "cuda", 7, 12), - (321, "cuda", 12, 12), - ], -) -def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps): - # Common data - torch.set_printoptions(precision=10) - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, _ = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - gradient_accumulation_steps=gradient_accumulation_steps, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(leg_loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(legacy_loss, 
experimental_loss) - - -@pytest.mark.parametrize( - "seed,device,optimizer_config,lr_scheduler, get_lr_this_step", - [ - ( - 0, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 0, - "cuda", - optim.LambConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 0, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.ConstantWarmupLRScheduler, - _test_commons.legacy_constant_lr_scheduler, - ), - ( - 42, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 42, - "cuda", - optim.LambConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 42, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.LinearWarmupLRScheduler, - _test_commons.legacy_linear_lr_scheduler, - ), - ( - 123, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 123, - "cuda", - optim.LambConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 123, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.CosineWarmupLRScheduler, - _test_commons.legacy_cosine_lr_scheduler, - ), - ( - 321, - "cuda", - optim.AdamConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ( - 321, - "cuda", - optim.LambConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ( - 321, - "cuda", - optim.SGDConfig, - optim.lr_scheduler.PolyWarmupLRScheduler, - _test_commons.legacy_poly_lr_scheduler, - ), - ], -) -def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_config, lr_scheduler, get_lr_this_step): - # Common data - total_steps = 10 - lr = 0.001 - warmup = 0.5 - cycles = 0.5 - power = 1.0 - lr_end = 1e-7 - 
torch.set_printoptions(precision=10) - - # Setup experimental API - torch.manual_seed(seed) - set_seed(seed) - if ( - lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler - or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler - ): - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) - elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) - elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler: - lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) - else: - raise RuntimeError("Invalid lr_scheduler") - - options = orttrainer.ORTTrainerOptions( - {"device": {"id": device}, "debug": {"deterministic_compute": True}, "lr_scheduler": lr_scheduler} - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optimizer_config(lr=lr) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - experimental_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - exp_loss, exp_preds = trainer.train_step(data, targets) - experimental_loss.append(exp_loss.cpu()) - - # Setup legacy API - torch.manual_seed(seed) - set_seed(seed) - - if optimizer_config == optim.AdamConfig: - legacy_optimizer_config = "AdamOptimizer" - elif optimizer_config == optim.LambConfig: - legacy_optimizer_config = "LambOptimizer" - elif optimizer_config == optim.SGDConfig: - legacy_optimizer_config = "SGDOptimizer" - else: - raise RuntimeError("Invalid optimizer_config") - - if ( - get_lr_this_step == _test_commons.legacy_constant_lr_scheduler - or get_lr_this_step == _test_commons.legacy_linear_lr_scheduler - ): - get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup) - elif get_lr_this_step == 
_test_commons.legacy_cosine_lr_scheduler: - get_lr_this_step = partial( - get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, cycles=cycles - ) - elif get_lr_this_step == _test_commons.legacy_poly_lr_scheduler: - get_lr_this_step = partial( - get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end - ) - else: - raise RuntimeError("Invalid get_lr_this_step") - - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - legacy_optimizer_config, - None, - lr_desc, - device=device, - _use_deterministic_compute=True, - get_lr_this_step=get_lr_this_step, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - leg_loss, leg_preds = legacy_trainer.train_step(data, targets) - legacy_loss.append(leg_loss.cpu()) - - # Compare legacy vs experimental APIs - _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) - - -def testLossScalerLegacyAndExperimentalFullCycle(): - orttrainer.TrainStepInfo( - optimizer_config=optim.LambConfig(lr=0.001), all_finite=True, fetches=[], optimization_step=0, step=0 - ) - new_ls = amp.DynamicLossScaler() - old_ls = Legacy_LossScaler("ort_test_input_loss_scaler", True) - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(new_ls.loss_scale, old_ls.loss_scale_) - assert new_ls.up_scale_window == old_ls.up_scale_window_ - assert_allclose(new_ls.min_loss_scale, old_ls.min_loss_scale_) - assert_allclose(new_ls.max_loss_scale, old_ls.max_loss_scale_) - - # Performing 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True) - for _cycles in range(1, 10): - # 1999 updates without overflow produces 1999 stable steps - for _i in range(1, 2000): - new_loss_scale = new_ls.update(train_step_info) - 
old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # 2000th update without overflow doubles the loss and zero stable steps until max_loss_scale is reached - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # After 8 cycles, loss scale should be float(1 << 16)*(2**8) - assert_allclose(new_loss_scale, old_loss_scale) - - # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on - for _count in range(1, 2050): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # Setting train_step_info.all_finite = False to test down scaling - train_step_info.all_finite = False - - # Performing 24 updates to half the loss scale each time - for _count in range(1, 25): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - - # After 24 updates with gradient overflow, loss scale is 1.0 - assert_allclose(new_loss_scale, old_loss_scale) - - # After 25 updates, min_loss_scale is reached and loss scale is not halfed from that point on - for _count in range(1, 5): - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, 
old_loss_scale) - - -def testLossScalerLegacyAndExperimentalRandomAllFinite(): - new_ls = amp.DynamicLossScaler() - old_ls = Legacy_LossScaler("ort_test_input_loss_scaler", True) - - # Initial state - train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(new_ls.loss_scale, old_ls.loss_scale_) - assert new_ls.up_scale_window == old_ls.up_scale_window_ - assert_allclose(new_ls.min_loss_scale, old_ls.min_loss_scale_) - assert_allclose(new_ls.max_loss_scale, old_ls.max_loss_scale_) - - import random - - out = [] - for _ in range(1, 64): - train_step_info.all_finite = bool(random.getrandbits(1)) - new_loss_scale = new_ls.update(train_step_info) - old_ls.update_loss_scale(train_step_info.all_finite) - old_loss_scale = old_ls.loss_scale_ - assert new_ls._stable_steps_count == old_ls.stable_steps_ - assert_allclose(new_loss_scale, old_loss_scale) - out.append(new_loss_scale) - assert new_loss_scale > 1e-7 - - -def testORTTrainerRunSymbolicShapeInfer(): - # Common data - seed = 0 - total_steps = 12 - device = "cuda" - torch.set_printoptions(precision=10) - - # Setup without symbolic shape inference - torch.manual_seed(seed) - set_seed(seed) - options = orttrainer.ORTTrainerOptions({"device": {"id": device}, "debug": {"deterministic_compute": True}}) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - expected_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - expected_loss.append(loss.cpu()) - - # Setup with symbolic shape inference - torch.manual_seed(seed) - set_seed(seed) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001) - 
options.utils.run_symbolic_shape_infer = True - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - # Training loop - new_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - new_loss.append(loss.cpu()) - - # Setup with symbolic shape inference in legacy API - torch.manual_seed(seed) - set_seed(seed) - model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer( - model, - my_loss, - model_desc, - "LambOptimizer", - None, - lr_desc, - device=device, - run_symbolic_shape_infer=True, - _use_deterministic_compute=True, - ) - # Training loop - legacy_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = legacy_trainer.train_step(data, targets, torch.tensor([optim_config.lr])) - legacy_loss.append(loss.cpu()) - - # Compare losses - _test_helpers.assert_model_outputs(new_loss, expected_loss) - _test_helpers.assert_model_outputs(legacy_loss, expected_loss) - - -@pytest.mark.parametrize( - "test_input", - [ - ( - { - "distributed": {"enable_adasum": True}, - } - ) - ], -) -def testORTTrainerOptionsEnabledAdasumFlag(test_input): - """Test the enabled_adasum flag values when set enabled""" - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values.distributed.enable_adasum is True - - -@pytest.mark.parametrize( - "test_input", - [ - ( - { - "distributed": {"enable_adasum": False}, - } - ) - ], -) -def testORTTrainerOptionsDisabledAdasumFlag(test_input): - """Test the enabled_adasum flag values when set disabled""" - - actual_values = orttrainer_options.ORTTrainerOptions(test_input) - assert actual_values.distributed.enable_adasum is False - - -def testORTTrainerUnusedInput(): - class UnusedInputModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, 
x, y): - return torch.mean(x) - - model = UnusedInputModel() - model_desc = {"inputs": [("x", [1]), ("y", [1])], "outputs": [("loss", [], True)]} - optim_config = optim.LambConfig(lr=0.001) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config) - - # Run just one step to make sure there are no iobinding errors for the unused input. - try: - trainer.train_step(torch.FloatTensor([1.0]), torch.FloatTensor([1.0])) - except RuntimeError: - pytest.fail("RuntimeError doing train_step with an unused input.") - - -@pytest.mark.parametrize( - "debug_files", - [ - { - "model_after_graph_transforms_path": "transformed.onnx", - "model_with_gradient_graph_path": "transformed_grad.onnx", - "model_with_training_graph_path": "training.onnx", - "model_with_training_graph_after_optimization_path": "training_optimized.onnx", - }, - {"model_after_graph_transforms_path": "transformed.onnx", "model_with_training_graph_path": ""}, - ], -) -def testTrainingGraphExport(debug_files): - device = "cuda" - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - - with tempfile.TemporaryDirectory() as tempdir: - debug_paths = {} - for k, v in debug_files.items(): - debug_paths[k] = os.path.join(tempdir, v) - opts = orttrainer.ORTTrainerOptions({"device": {"id": device}, "debug": {"graph_save_paths": debug_paths}}) - optim_config = optim.AdamConfig() - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - data, targets = batcher_fn(train_data, 0) - trainer.train_step(data, targets) - for k, v in debug_files.items(): - path = debug_paths[k] - if len(v) > 0: - assert os.path.isfile(path) - saved_graph = onnx.load(path).graph - if k == "model_with_training_graph_path": - assert any("AdamOptimizer" in n.op_type for n in saved_graph.node) - elif k == "model_with_gradient_graph_path": - assert any("Grad" in n.name for n in saved_graph.node) - elif k == "model_after_graph_transforms_path": 
- assert any("LayerNormalization" in n.op_type for n in saved_graph.node) - elif k == "model_with_training_graph_after_optimization_path": - assert any("FusedMatMul" in n.op_type for n in saved_graph.node) - # remove saved file - os.remove(path) - else: - assert not os.path.isfile(path) - - -def _adam_max_norm_clip_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.592951, - 10.067989, - 9.619152, - 9.245731, - 8.881137, - 8.578644, - 8.280573, - 8.063023, - 7.797933, - 7.486215, - 7.233806, - 7.011791, - ], - 14: [ - 10.584141, - 10.068119, - 9.581743, - 9.191472, - 8.880169, - 8.5352, - 8.311425, - 8.061202, - 7.773032, - 7.523009, - 7.258711, - 7.02805, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.592951, - 10.068722, - 9.620503, - 9.247791, - 8.883972, - 8.582286, - 8.285027, - 8.068308, - 7.803638, - 7.492318, - 7.240352, - 7.018665, - ], - 14: [ - 10.584141, - 10.068845, - 9.583107, - 9.193537, - 8.882966, - 8.538839, - 8.315872, - 8.066408, - 7.778978, - 7.529708, - 7.265849, - 7.035439, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.647908, - 10.144501, - 9.672352, - 9.306980, - 8.956026, - 8.602655, - 8.351079, - 8.088144, - 7.867220, - 7.564082, - 7.289846, - 7.073726, - ], - 14: [ - 10.697515, - 10.229034, - 9.765422, - 9.428294, - 9.080612, - 8.715208, - 8.459574, - 8.169073, - 7.940211, - 7.654147, - 7.390446, - 7.166227, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.647908, - 10.145191, - 9.673690, - 9.309031, - 8.959020, - 8.606632, - 8.355836, - 8.093478, - 7.873327, - 7.570731, - 7.296772, - 7.0809422, - ], - 14: [ - 10.697515, - 10.22967, - 9.766556, - 9.430037, - 9.083106, - 8.718601, - 8.463726, - 8.17396, - 7.945755, - 7.660188, - 7.396963, - 7.172944, - ], - }, - ), - ] - elif device_capability_major == 5: # M60 for CI machines 
(Python Packaging Pipeline) - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.618382, - 10.08292, - 9.603334, - 9.258133, - 8.917768, - 8.591574, - 8.318401, - 8.042292, - 7.783608, - 7.50226, - 7.236041, - 7.035602, - ], - 14: [ - 10.618382, - 10.08292, - 9.603334, - 9.258133, - 8.917768, - 8.591574, - 8.318401, - 8.042292, - 7.783608, - 7.50226, - 7.236041, - 7.035602, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.618382, - 10.083632, - 9.604639, - 9.260109, - 8.920504, - 8.595082, - 8.322799, - 8.047493, - 7.78929, - 7.508382, - 7.242587, - 7.042367, - ], - 14: [ - 10.618382, - 10.083632, - 9.604639, - 9.260109, - 8.920504, - 8.595082, - 8.322799, - 8.047493, - 7.78929, - 7.508382, - 7.242587, - 7.042367, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.68639, - 10.102986, - 9.647681, - 9.293091, - 8.958928, - 8.625297, - 8.351107, - 8.079577, - 7.840723, - 7.543044, - 7.284141, - 7.072688, - ], - 14: [ - 10.68639, - 10.102986, - 9.647681, - 9.293091, - 8.958928, - 8.625297, - 8.351107, - 8.079577, - 7.840723, - 7.543044, - 7.284141, - 7.072688, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.68639, - 10.103672, - 9.649025, - 9.295167, - 8.961777, - 8.629059, - 8.355571, - 8.084871, - 7.846589, - 7.549438, - 7.290722, - 7.079446, - ], - 14: [ - 10.697515, - 10.22967, - 9.766556, - 9.430037, - 9.083106, - 8.718601, - 8.463726, - 8.17396, - 7.945755, - 7.660188, - 7.396963, - 7.172944, - ], - }, - ), - ] - - -@pytest.mark.parametrize( - "seed,device,max_norm_clip,gradient_accumulation_steps,total_steps,expected_loss", _adam_max_norm_clip_data() -) -def testORTTrainerAdamMaxNormClip(seed, device, max_norm_clip, gradient_accumulation_steps, total_steps, expected_loss): - rtol = 1e-5 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": 
gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.AdamConfig(lr=0.001, max_norm_clip=max_norm_clip) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu().item()) - - # Compare legacy vs experimental APIs - assert trainer._onnx_model is not None - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, rtol=rtol) - - -def _lamb_max_norm_clip_data(): - device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.592951, - 10.487728, - 10.422251, - 10.350913, - 10.244248, - 10.213003, - 10.129222, - 10.095112, - 10.035983, - 9.974586, - 9.909771, - 9.874278, - ], - 14: [ - 10.584141, - 10.497192, - 10.389251, - 10.286045, - 10.231354, - 10.17018, - 10.066779, - 10.048138, - 9.958029, - 9.8908, - 9.82965, - 9.755484, - ], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.592951, - 10.452503, - 10.349832, - 10.245314, - 10.106587, - 10.046009, - 9.934781, - 9.875164, - 9.792067, - 9.704592, - 9.617104, - 9.563070, - ], - 14: [ - 10.584141, - 10.461154, - 10.315399, - 10.178979, - 10.092329, - 9.999928, - 9.869949, - 9.824564, - 9.707565, - 9.61643, - 9.532847, - 9.439593, - ], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.647908, - 10.566276, - 10.476154, - 10.406275, - 10.311079, - 10.240053, - 10.196469, - 10.113955, - 10.117376, - 10.013077, - 9.930301, - 9.893368, - ], - 14: [ - 10.697515, - 10.631279, - 10.528757, - 10.496689, - 10.411219, - 
10.322109, - 10.297314, - 10.215549, - 10.149698, - 10.087336, - 10.010884, - 9.934544, - ], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.647908, - 10.531957, - 10.405246, - 10.302971, - 10.176583, - 10.075583, - 10.005772, - 9.897825, - 9.875748, - 9.748932, - 9.642885, - 9.586762, - ], - 14: [ - 10.697515, - 10.596729, - 10.457815, - 10.393475, - 10.277581, - 10.158909, - 10.108126, - 10.000326, - 9.912526, - 9.826057, - 9.727899, - 9.633768, - ], - }, - ), - ] - elif device_capability_major == 5: # M60 for CI machines (Python Packaging Pipeline) - return [ - ( - 0, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.618382, - 10.50222, - 10.403347, - 10.35298, - 10.288447, - 10.237399, - 10.184225, - 10.089048, - 10.008952, - 9.972644, - 9.897674, - 9.84524, - ], - 14: [0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 0, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.618382, - 10.466732, - 10.330871, - 10.24715, - 10.150972, - 10.069127, - 9.98974, - 9.870169, - 9.763693, - 9.704323, - 9.605957, - 9.533117, - ], - 14: [1, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 42, - "cuda", - 1.0, - 1, - 12, - { - 12: [ - 10.68639, - 10.511692, - 10.447308, - 10.405255, - 10.334866, - 10.261473, - 10.169422, - 10.107138, - 10.069889, - 9.97798, - 9.928105, - 9.896435, - ], - 14: [2, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ( - 42, - "cuda", - 0.1, - 1, - 12, - { - 12: [ - 10.68639, - 10.477489, - 10.376671, - 10.301725, - 10.200718, - 10.098477, - 9.97995, - 9.890104, - 9.828899, - 9.713555, - 9.639567, - 9.589856, - ], - 14: [3, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], - }, - ), - ] - - -@pytest.mark.parametrize( - "seed,device,max_norm_clip, gradient_accumulation_steps,total_steps,expected_loss", _lamb_max_norm_clip_data() -) -def testORTTrainerLambMaxNormClip(seed, device, max_norm_clip, gradient_accumulation_steps, total_steps, expected_loss): - rtol = 1e-3 - torch.manual_seed(seed) - set_seed(seed) - - # Setup ORTTrainer - options = 
orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "debug": {"deterministic_compute": True}, - } - ) - model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) - optim_config = optim.LambConfig(lr=0.001, max_norm_clip=max_norm_clip) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) - - # Training loop - actual_loss = [] - for i in range(total_steps): - data, targets = batcher_fn(train_data, i) - loss, _ = trainer.train_step(data, targets) - actual_loss.append(loss.cpu().item()) - - # Compare legacy vs experimental APIs - opset = get_model_opset(trainer._onnx_model) - _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, rtol=rtol) diff --git a/orttraining/orttraining/test/python/orttraining_test_transformers.py b/orttraining/orttraining/test/python/orttraining_test_transformers.py deleted file mode 100644 index dbaf4a293c466..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_transformers.py +++ /dev/null @@ -1,480 +0,0 @@ -import random -import unittest - -import numpy as np -import torch -from numpy.testing import assert_allclose -from orttraining_test_data_loader import BatchArgsOption, ids_tensor -from orttraining_test_utils import get_lr, run_test -from transformers import BertConfig, BertForPreTraining - -import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription, LossScaler, ModelDescription, ORTTrainer # noqa: F401 - - -class BertModelTest(unittest.TestCase): - class BertModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - 
attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - scope=None, - device="cpu", - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.scope = scope - self.device = device - - # 1. 
superset of bert input/output descs - # see BertPreTrainedModel doc - self.input_ids_desc = IODescription( - "input_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size - ) - self.attention_mask_desc = IODescription( - "attention_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2 - ) - self.token_type_ids_desc = IODescription( - "token_type_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2 - ) - self.position_ids_desc = IODescription( - "position_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.max_position_embeddings - ) - self.head_mask_desc = IODescription( - "head_mask", [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2 - ) - self.inputs_embeds_desc = IODescription( - "inputs_embeds", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - - self.encoder_hidden_states_desc = IODescription( - "encoder_hidden_states", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - self.encoder_attention_mask_desc = IODescription( - "encoder_attention_mask", ["batch", "max_seq_len_in_batch"], torch.float32 - ) - - # see BertForPreTraining doc - self.masked_lm_labels_desc = IODescription( - "masked_lm_labels", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size - ) - self.next_sentence_label_desc = IODescription( - "next_sentence_label", - [ - "batch", - ], - torch.int64, - num_classes=2, - ) - - # outputs - self.loss_desc = IODescription( - "loss", - [ - 1, - ], - torch.float32, - ) - self.prediction_scores_desc = IODescription( - "prediction_scores", ["batch", "max_seq_len_in_batch", self.vocab_size], torch.float32 - ) - - self.seq_relationship_scores_desc = IODescription( - "seq_relationship_scores", ["batch", 2], torch.float32 - ) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32) - self.hidden_states_desc = IODescription( - "hidden_states", - 
[self.num_hidden_layers, "batch", "max_seq_len_in_batch", self.hidden_size], - torch.float32, - ) - self.attentions_desc = IODescription( - "attentions", - [ - self.num_hidden_layers, - "batch", - self.num_attention_heads, - "max_seq_len_in_batch", - "max_seq_len_in_batch", - ], - torch.float32, - ) - self.last_hidden_state_desc = IODescription( - "last_hidden_state", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 - ) - self.pooler_output_desc = IODescription("pooler_output", ["batch", self.hidden_size], torch.float32) - - def BertForPreTraining_descs(self): - return ModelDescription( - [ - self.input_ids_desc, - self.attention_mask_desc, - self.token_type_ids_desc, - self.masked_lm_labels_desc, - self.next_sentence_label_desc, - ], - # returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided - # hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states - [ - self.loss_desc, - self.prediction_scores_desc, - self.seq_relationship_scores_desc, - # hidden_states_desc, attentions_desc - ], - ) - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device) - - input_mask = None - if self.use_input_mask: - input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2).to(self.device) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size).to(self.device) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size).to(self.device) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels).to(self.device) - choice_labels = ids_tensor([self.batch_size], self.num_choices).to(self.device) - - config = BertConfig( - vocab_size=self.vocab_size, - 
vocab_size_or_config_json_file=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - ) - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def create_and_check_bert_for_pretraining( - self, - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - option_use_internal_get_lr_this_step=[True], # noqa: B006 - option_use_internal_loss_scaler=[True], # noqa: B006 - ): - seed = 42 - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - model = BertForPreTraining(config=config) - model.eval() - loss, prediction_scores, seq_relationship_score = model( - input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - masked_lm_labels=token_labels, - next_sentence_label=sequence_labels, - ) - model_desc = ModelDescription( - [ - self.input_ids_desc, - self.attention_mask_desc, - self.token_type_ids_desc, - self.masked_lm_labels_desc, - self.next_sentence_label_desc, - ], - [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc], - ) - - from collections import namedtuple - - MyArgs = namedtuple( - "MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len" - ) - - dataset_len = 100 - epochs = 8 - max_steps = epochs * dataset_len - args = MyArgs( - local_rank=0, - world_size=1, - 
max_steps=max_steps, - learning_rate=0.00001, - warmup_proportion=0.01, - batch_size=13, - seq_len=7, - ) - - def get_lr_this_step(global_step): - return get_lr(args, global_step) - - loss_scaler = LossScaler("loss_scale_input_name", True, up_scale_window=2000) - - for fp16 in option_fp16: - for allreduce_post_accumulation in option_allreduce_post_accumulation: - for gradient_accumulation_steps in option_gradient_accumulation_steps: - for use_internal_get_lr_this_step in option_use_internal_get_lr_this_step: - for use_internal_loss_scaler in option_use_internal_loss_scaler: - for split_batch in option_split_batch: - print("gradient_accumulation_steps:", gradient_accumulation_steps) - print("split_batch:", split_batch) - - seed = 42 - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - ( - old_api_loss_ort, - old_api_prediction_scores_ort, - old_api_seq_relationship_score_ort, - ) = run_test( - model, - model_desc, - self.device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - split_batch, - dataset_len, - epochs, - use_new_api=False, - ) - - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - if use_internal_get_lr_this_step and use_internal_loss_scaler: - ( - new_api_loss_ort, - new_api_prediction_scores_ort, - new_api_seq_relationship_score_ort, - ) = run_test( - model, - model_desc, - self.device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - split_batch, - dataset_len, - epochs, - use_new_api=True, - ) - - assert_allclose(old_api_loss_ort, new_api_loss_ort) - assert_allclose(old_api_prediction_scores_ort, new_api_prediction_scores_ort) - 
assert_allclose( - old_api_seq_relationship_score_ort, new_api_seq_relationship_score_ort - ) - - def setUp(self): - self.model_tester = BertModelTest.BertModelTester(self) - - def test_for_pretraining_mixed_precision(self): - # It would be better to test both with/without mixed precision and allreduce_post_accumulation. - # However, stress test of all the 4 cases is not stable at least on the test machine. - # There we only test mixed precision and allreduce_post_accumulation because it is the most useful use cases. - option_fp16 = [True] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_mixed_precision_with_gradient_accumulation(self): - # It would be better to test both with/without mixed precision and allreduce_post_accumulation. - # However, stress test of all the 4 cases is not stable at least on the test machine. - # There we only test mixed precision and allreduce_post_accumulation because it is the most useful use cases. - option_fp16 = [True] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_all(self): - # This test is not stable because it create and run ORTSession multiple times. - # It occasionally gets seg fault at ~MemoryPattern() - # when releasing patterns_. 
In order not to block PR merging CI test, - # this test is broke into following individual tests. - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1, 8] - option_split_batch = [BatchArgsOption.List, BatchArgsOption.Dict, BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_list_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.List] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.Dict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_list_and_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [1] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - 
option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_list_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.List] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.Dict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - def test_for_pretraining_full_precision_grad_accumulation_list_and_dict_input(self): - option_fp16 = [False] - option_allreduce_post_accumulation = [True] - option_gradient_accumulation_steps = [8] - option_split_batch = [BatchArgsOption.ListAndDict] - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_bert_for_pretraining( - *config_and_inputs, - option_fp16, - option_allreduce_post_accumulation, - option_gradient_accumulation_steps, - option_split_batch, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_utils.py b/orttraining/orttraining/test/python/orttraining_test_utils.py deleted file mode 100644 index 527cfb8a0ba7d..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_test_utils.py +++ /dev/null @@ -1,246 +0,0 @@ -import math - -import torch -from 
orttraining_test_data_loader import BatchArgsOption, create_ort_test_dataloader, split_batch - -from onnxruntime.capi.ort_trainer import IODescription, ORTTrainer -from onnxruntime.training import amp, optim, orttrainer -from onnxruntime.training.optim import _LRScheduler - - -def warmup_cosine(x, warmup=0.002): - if x < warmup: - return x / warmup - return 0.5 * (1.0 + torch.cos(math.pi * x)) - - -def warmup_constant(x, warmup=0.002): - if x < warmup: - return x / warmup - return 1.0 - - -def warmup_linear(x, warmup=0.002): - if x < warmup: - return x / warmup - return max((x - 1.0) / (warmup - 1.0), 0.0) - - -def warmup_poly(x, warmup=0.002, degree=0.5): - if x < warmup: - return x / warmup - return (1.0 - x) ** degree - - -SCHEDULES = { - "warmup_cosine": warmup_cosine, - "warmup_constant": warmup_constant, - "warmup_linear": warmup_linear, - "warmup_poly": warmup_poly, -} - - -def get_lr(args, training_steps, schedule="warmup_poly"): - if args.max_steps == -1: - return args.learning_rate - - schedule_fct = SCHEDULES[schedule] - return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion) - - -def map_optimizer_attributes(name): - no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) - if no_decay: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - else: - return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - - -class WrapLRScheduler(_LRScheduler): - def __init__(self, get_lr_this_step): - super().__init__() - self.get_lr_this_step = get_lr_this_step - - def get_lr(self, train_step_info): - return [self.get_lr_this_step(train_step_info.optimization_step)] - - -def run_test( - model, - model_desc, - device, - args, - gradient_accumulation_steps, - fp16, - allreduce_post_accumulation, - get_lr_this_step, - use_internal_get_lr_this_step, - loss_scaler, - use_internal_loss_scaler, - batch_args_option, - dataset_len, - 
epochs, - use_new_api, -): - dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device) - - if use_new_api: - assert use_internal_loss_scaler, "new api should always use internal loss scaler" - - new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step) - - new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, - "device": {"id": device}, - "mixed_precision": {"enabled": fp16, "loss_scaler": new_api_loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": True}, - "distributed": {"allreduce_post_accumulation": True}, - "lr_scheduler": new_api_lr_scheduler, - } - ) - - param_optimizer = list(model.named_parameters()) - params = [ - { - "params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - { - "params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], - "alpha": 0.9, - "beta": 0.999, - "lambda": 0.0, - "epsilon": 1e-6, - }, - ] - - vocab_size = 99 - new_model_desc = { - "inputs": [ - ( - "input_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "attention_mask", - ["batch", "max_seq_len_in_batch"], - ), - ( - "token_type_ids", - ["batch", "max_seq_len_in_batch"], - ), - ( - "masked_lm_labels", - ["batch", "max_seq_len_in_batch"], - ), - ( - "next_sentence_label", - [ - "batch", - ], - ), - ], - "outputs": [ - ( - "loss", - [ - 1, - ], - True, - ), - ("prediction_scores", ["batch", "max_seq_len_in_batch", vocab_size]), - ("seq_relationship_scores", ["batch", 2]), - ], - } - - optim_config = optim.LambConfig(params=params, lr=2e-5) - model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options) - print("running with new frontend API") - else: - model = ORTTrainer( - model, - None, - model_desc, 
- "LambOptimizer", - map_optimizer_attributes=map_optimizer_attributes, - learning_rate_description=IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device=device, - _enable_internal_postprocess=True, - gradient_accumulation_steps=gradient_accumulation_steps, - # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6 - world_rank=args.local_rank, - world_size=args.world_size, - use_mixed_precision=fp16, - allreduce_post_accumulation=allreduce_post_accumulation, - get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None, - loss_scaler=loss_scaler if use_internal_loss_scaler else None, - _opset_version=14, - _use_deterministic_compute=True, - ) - print("running with old frontend API") - - # training loop - eval_batch = None - if not use_new_api: - model.train() - for _epoch in range(epochs): - for step, batch in enumerate(dataloader): - if eval_batch is None: - eval_batch = batch - - if not use_internal_get_lr_this_step: - lr = get_lr_this_step(step) - learning_rate = torch.tensor([lr]) - - if not use_internal_loss_scaler and fp16: - loss_scale = torch.tensor([loss_scaler.loss_scale_]) - - if batch_args_option == BatchArgsOption.List: - if not use_internal_get_lr_this_step: - batch = [*batch, learning_rate] # noqa: PLW2901 - if not use_internal_loss_scaler and fp16: - batch = [*batch, loss_scale] # noqa: PLW2901 - outputs = model.train_step(*batch) - elif batch_args_option == BatchArgsOption.Dict: - args, kwargs = split_batch(batch, model_desc.inputs_, 0) - if not use_internal_get_lr_this_step: - kwargs["Learning_Rate"] = learning_rate - if not use_internal_loss_scaler and fp16: - kwargs[model.loss_scale_input_name] = loss_scale - outputs = model.train_step(*args, **kwargs) - else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs - args, kwargs = split_batch(batch, model_desc.inputs_, args_count) - if not use_internal_get_lr_this_step: - kwargs["Learning_Rate"] = learning_rate - if not 
use_internal_loss_scaler and fp16: - kwargs[model.loss_scale_input_name] = loss_scale - outputs = model.train_step(*args, **kwargs) - - # eval - if batch_args_option == BatchArgsOption.List: - outputs = model.eval_step(*batch) - elif batch_args_option == BatchArgsOption.Dict: - args, kwargs = split_batch(batch, model_desc.inputs_, 0) - outputs = model.eval_step(*args, **kwargs) - else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs - args, kwargs = split_batch(batch, model_desc.inputs_, args_count) - outputs = model.eval_step(*args, **kwargs) - - return (output.cpu().numpy() for output in outputs) diff --git a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py b/orttraining/orttraining/test/python/orttraining_transformer_trainer.py deleted file mode 100644 index bce726871bacf..0000000000000 --- a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py +++ /dev/null @@ -1,357 +0,0 @@ -# adapted from Trainer.py of huggingface transformers - -import json -import logging -import os -import random -from typing import Callable, Dict, List, NamedTuple, Optional - -import numpy as np -import torch -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Dataset -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import SequentialSampler -from tqdm import tqdm, trange -from transformers.data.data_collator import DefaultDataCollator -from transformers.modeling_utils import PreTrainedModel -from transformers.training_args import TrainingArguments - -import onnxruntime -from onnxruntime.training import amp, optim, orttrainer - -try: - from torch.utils.tensorboard import SummaryWriter - - _has_tensorboard = True -except ImportError: - try: - from tensorboardX import SummaryWriter # noqa: F401 - - _has_tensorboard = True - except ImportError: - _has_tensorboard = False - - -def is_tensorboard_available(): - return _has_tensorboard - - 
-logger = logging.getLogger(__name__) - - -def set_seed(seed: int): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - onnxruntime.set_seed(seed) - - -class EvalPrediction(NamedTuple): - predictions: np.ndarray - label_ids: np.ndarray - - -class PredictionOutput(NamedTuple): - predictions: np.ndarray - label_ids: Optional[np.ndarray] - metrics: Optional[Dict[str, float]] - - -class TrainOutput(NamedTuple): - global_step: int - training_loss: float - - -def get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps, base_lr): - def lr_lambda_linear(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) - - def lambda_lr_get_lr(current_global_step): - # LambdaLR increment self.last_epoch at evert sept() - return base_lr * lr_lambda_linear(current_global_step) - - return lambda_lr_get_lr - - -class ORTTransformerTrainer: - """ """ - - model: PreTrainedModel - args: TrainingArguments - train_dataset: Dataset - eval_dataset: Dataset - compute_metrics: Callable[[EvalPrediction], Dict] - - def __init__( - self, - model: PreTrainedModel, - model_desc: dict, - args: TrainingArguments, - train_dataset: Dataset, - eval_dataset: Dataset, - compute_metrics: Callable[[EvalPrediction], Dict], - world_size: Optional[int] = 1, - ): - """ """ - - self.model = model - self.model_desc = model_desc - self.args = args - self.world_size = world_size - self.data_collator = DefaultDataCollator() - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self.compute_metrics = compute_metrics - set_seed(self.args.seed) - # Create output directory if needed - if self.args.local_rank in [-1, 0]: - os.makedirs(self.args.output_dir, exist_ok=True) - - def get_train_dataloader(self) -> DataLoader: - if self.train_dataset is None: - raise 
ValueError("Trainer: training requires a train_dataset.") - train_sampler = ( - SequentialSampler(self.train_dataset) - if self.args.local_rank == -1 - else DistributedSampler(self.train_dataset) - ) - return DataLoader( - self.train_dataset, - batch_size=self.args.train_batch_size, - sampler=train_sampler, - collate_fn=self.data_collator.collate_batch, - ) - - def get_eval_dataloader(self) -> DataLoader: - return DataLoader( - self.eval_dataset, - batch_size=self.args.eval_batch_size, - shuffle=False, - collate_fn=self.data_collator.collate_batch, - ) - - def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: - # We use the same batch_size as for eval. - return DataLoader( - test_dataset, - batch_size=self.args.eval_batch_size, - shuffle=False, - collate_fn=self.data_collator.collate_batch, - ) - - def train(self): - """ - Main training entry point. - """ - train_dataloader = self.get_train_dataloader() - - if self.args.max_steps > 0: - t_total = self.args.max_steps - num_train_epochs = ( - self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1 - ) - else: - t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) - num_train_epochs = self.args.num_train_epochs - - lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps / float(t_total)) - - loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None - device = self.args.device.type - - device = f"{device}:{self.args.device.index}" if self.args.device.index else f"{device}:0" - options = orttrainer.ORTTrainerOptions( - { - "batch": {"gradient_accumulation_steps": self.args.gradient_accumulation_steps}, - "device": {"id": device}, - "mixed_precision": {"enabled": self.args.fp16, "loss_scaler": loss_scaler}, - "debug": { - "deterministic_compute": True, - }, - "utils": {"grad_norm_clip": False}, - "distributed": { - # we are running single node multi gpu test. 
thus world_rank = local_rank - # and world_size = self.args.n_gpu - "world_rank": max(0, self.args.local_rank), - "world_size": int(self.world_size), - "local_rank": max(0, self.args.local_rank), - "allreduce_post_accumulation": True, - }, - "lr_scheduler": lr_scheduler, - } - ) - - param_optimizer = list(self.model.named_parameters()) - params = [ - { - "params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], - "weight_decay_mode": 1, - }, - { - "params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], - "weight_decay_mode": 1, - }, - ] - - optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) - self.model = orttrainer.ORTTrainer(self.model, self.model_desc, optim_config, options=options) - - # Train! - logger.info("***** Running training *****") - logger.info(" Num examples = %d", len(train_dataloader.dataset)) - logger.info(" Num Epochs = %d", num_train_epochs) - logger.info(" Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size) - logger.info( - " Total train batch size (w. 
parallel, distributed & accumulation) = %d", - self.args.train_batch_size - * self.args.gradient_accumulation_steps - * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), - ) - logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) - logger.info(" Total optimization steps = %d", t_total) - - global_step = 0 - epochs_trained = 0 - steps_trained_in_current_epoch = 0 - - tr_loss = 0.0 - logging_loss = 0.0 - train_iterator = trange( - epochs_trained, - int(num_train_epochs), - desc="Epoch", - disable=self.args.local_rank not in [-1, 0], - ) - - for _epoch in train_iterator: - epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) - for step, inputs in enumerate(epoch_iterator): - # Skip past any already trained steps if resuming training - if steps_trained_in_current_epoch > 0: - steps_trained_in_current_epoch -= 1 - continue - - tr_loss += self._training_step(self.model, inputs) - - if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( - len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator) - ): - global_step += 1 - - if self.args.local_rank in [-1, 0]: - if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or ( - global_step == 1 and self.args.logging_first_step - ): - logs = {} - if self.args.evaluate_during_training: - results = self.evaluate() - for key, value in results.items(): - eval_key = f"eval_{key}" - logs[eval_key] = value - - loss_scalar = (tr_loss - logging_loss) / self.args.logging_steps - - logs["loss"] = loss_scalar - logging_loss = tr_loss - - epoch_iterator.write(json.dumps({**logs, **{"step": global_step}})) - - if self.args.max_steps > 0 and global_step > self.args.max_steps: - epoch_iterator.close() - break - if self.args.max_steps > 0 and global_step > self.args.max_steps: - train_iterator.close() - break - - logger.info("\n\nTraining completed. 
\n\n") - return TrainOutput(global_step, tr_loss / global_step) - - def _training_step(self, model, inputs: Dict[str, torch.Tensor]) -> float: - for k, v in inputs.items(): - inputs[k] = v.to(self.args.device) - - outputs = model.train_step(**inputs) - loss = outputs[0] # model outputs are always tuple in transformers (see doc) - - return loss.item() - - def save_model(self, output_dir: Optional[str] = None): - output_dir = output_dir if output_dir is not None else self.args.output_dir - os.makedirs(output_dir, exist_ok=True) - self.model.save_as_onnx(os.path.join(output_dir, "transformer.onnx")) - - def evaluate(self) -> Dict[str, float]: - """ - Run evaluation and return metrics. - - Returns: - A dict containing: - - the eval loss - - the potential metrics computed from the predictions - """ - eval_dataloader = self.get_eval_dataloader() - - output = self._prediction_loop(eval_dataloader, description="Evaluation") - return output.metrics - - def predict(self, test_dataset: Dataset) -> PredictionOutput: - """ - Run prediction and return predictions and potential metrics. - - Depending on the dataset and your use case, your test dataset may contain labels. - In that case, this method will also return metrics, like in evaluate(). - """ - test_dataloader = self.get_test_dataloader(test_dataset) - return self._prediction_loop(test_dataloader, description="Prediction") - - def _prediction_loop(self, dataloader: DataLoader, description: str) -> PredictionOutput: - """ - Prediction/evaluation loop, shared by `evaluate()` and `predict()`. - - Works both with or without labels. 
- """ - - logger.info("***** Running %s *****", description) - logger.info(" Num examples = %d", len(dataloader.dataset)) - logger.info(" Batch size = %d", dataloader.batch_size) - eval_losses: List[float] = [] - preds: np.ndarray = None - label_ids: np.ndarray = None - - for inputs in tqdm(dataloader, desc=description): - has_labels = any(inputs.get(k) is not None for k in ["labels", "masked_lm_labels"]) - - for k, v in inputs.items(): - inputs[k] = v.to(self.args.device) - - with torch.no_grad(): - outputs = self.model.eval_step(**inputs) - - if has_labels: - step_eval_loss, logits = outputs[:2] - eval_losses += [step_eval_loss.mean().item()] - else: - logits = outputs[0] - - if preds is None: - preds = logits.detach().cpu().numpy() - else: - preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) - if inputs.get("labels") is not None: - if label_ids is None: - label_ids = inputs["labels"].detach().cpu().numpy() - else: - label_ids = np.append(label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) - - if self.compute_metrics is not None and preds is not None and label_ids is not None: - metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) - else: - metrics = {} - if len(eval_losses) > 0: - metrics["loss"] = np.mean(eval_losses) - - return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) diff --git a/orttraining/orttraining/test/python/utils_multiple_choice.py b/orttraining/orttraining/test/python/utils_multiple_choice.py deleted file mode 100644 index e0febaf2d6334..0000000000000 --- a/orttraining/orttraining/test/python/utils_multiple_choice.py +++ /dev/null @@ -1,269 +0,0 @@ -# adapted from run_multiple_choice.py of huggingface transformers -# https://github.com/huggingface/transformers/blob/master/examples/multiple-choice/utils_multiple_choice.py - -import csv -import glob # noqa: F401 -import json # noqa: F401 -import logging -import os -from dataclasses import dataclass -from enum import 
Enum -from typing import List, Optional - -import torch -import tqdm -from filelock import FileLock -from torch.utils.data.dataset import Dataset -from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available # noqa: F401 - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class InputExample: - """ - A single training/test example for multiple choice - - Args: - example_id: Unique id for the example. - question: string. The untokenized text of the second sequence (question). - contexts: list of str. The untokenized text of the first sequence (context of corresponding question). - endings: list of str. multiple choice's options. Its length must be equal to contexts' length. - label: (Optional) string. The label of the example. This should be - specified for train and dev examples, but not for test examples. - """ - - example_id: str - question: str - contexts: List[str] - endings: List[str] - label: Optional[str] - - -@dataclass(frozen=True) -class InputFeatures: - """ - A single set of features of data. - Property names are the same names as the corresponding inputs to a model. 
- """ - - example_id: str - input_ids: List[List[int]] - attention_mask: Optional[List[List[int]]] - token_type_ids: Optional[List[List[int]]] - label: Optional[int] - - -class Split(Enum): - train = "train" - dev = "dev" - test = "test" - - -class DataProcessor: - """Base class for data converters for multiple choice data sets.""" - - def get_train_examples(self, data_dir): - """Gets a collection of `InputExample`s for the train set.""" - raise NotImplementedError() - - def get_dev_examples(self, data_dir): - """Gets a collection of `InputExample`s for the dev set.""" - raise NotImplementedError() - - def get_test_examples(self, data_dir): - """Gets a collection of `InputExample`s for the test set.""" - raise NotImplementedError() - - def get_labels(self): - """Gets the list of labels for this data set.""" - raise NotImplementedError() - - -class MultipleChoiceDataset(Dataset): - """ - This will be superseded by a framework-agnostic approach - soon. - """ - - features: List[InputFeatures] - - def __init__( - self, - data_dir: str, - tokenizer: PreTrainedTokenizer, - task: str, - processor: DataProcessor, - max_seq_length: Optional[int] = None, - overwrite_cache=False, - mode: Split = Split.train, - ): - cached_features_file = os.path.join( - data_dir, - "cached_{}_{}_{}_{}".format( - mode.value, - tokenizer.__class__.__name__, - str(max_seq_length), - task, - ), - ) - - # Make sure only the first process in distributed training processes the dataset, - # and the others will use the cache. 
- lock_path = cached_features_file + ".lock" - with FileLock(lock_path): - if os.path.exists(cached_features_file) and not overwrite_cache: - logger.info(f"Loading features from cached file {cached_features_file}") - self.features = torch.load(cached_features_file) - else: - logger.info(f"Creating features from dataset file at {data_dir}") - label_list = processor.get_labels() - if mode == Split.dev: - examples = processor.get_dev_examples(data_dir) - elif mode == Split.test: - examples = processor.get_test_examples(data_dir) - else: - examples = processor.get_train_examples(data_dir) - logger.info("Training examples: %s", len(examples)) - # TODO clean up all this to leverage built-in features of tokenizers - self.features = convert_examples_to_features( - examples, - label_list, - max_seq_length, - tokenizer, - pad_on_left=bool(tokenizer.padding_side == "left"), - pad_token=tokenizer.pad_token_id, - pad_token_segment_id=tokenizer.pad_token_type_id, - ) - logger.info("Saving features into cached file %s", cached_features_file) - torch.save(self.features, cached_features_file) - - def __len__(self): - return len(self.features) - - def __getitem__(self, i) -> InputFeatures: - return self.features[i] - - -class SwagProcessor(DataProcessor): - """Processor for the SWAG data set.""" - - def get_train_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} train") - return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train") - - def get_dev_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} dev") - return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev") - - def get_test_examples(self, data_dir): - """See base class.""" - logger.info(f"LOOKING AT {data_dir} dev") - raise ValueError( - "For swag testing, the input file does not contain a label column. It can not be tested in current code" - "setting!" 
- ) - return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test") - - def get_labels(self): - """See base class.""" - return ["0", "1", "2", "3"] - - def _read_csv(self, input_file): - with open(input_file, encoding="utf-8") as f: - return list(csv.reader(f)) - - def _create_examples(self, lines: List[List[str]], type: str): - """Creates examples for the training and dev sets.""" - if type == "train" and lines[0][-1] != "label": - raise ValueError("For training, the input file must contain a label column.") - - examples = [ - InputExample( - example_id=line[2], - question=line[5], # in the swag dataset, the - # common beginning of each - # choice is stored in "sent2". - contexts=[line[4], line[4], line[4], line[4]], - endings=[line[7], line[8], line[9], line[10]], - label=line[11], - ) - for line in lines[1:] # we skip the line with the column names - ] - - return examples - - -def convert_examples_to_features( - examples: List[InputExample], - label_list: List[str], - max_length: int, - tokenizer: PreTrainedTokenizer, - pad_token_segment_id=0, - pad_on_left=False, - pad_token=0, - mask_padding_with_zero=True, -) -> List[InputFeatures]: - """ - Loads a data file into a list of `InputFeatures` - """ - - label_map = {label: i for i, label in enumerate(label_list)} - - features = [] - for ex_index, example in tqdm.tqdm(enumerate(examples), desc="convert examples to features"): - if ex_index % 10000 == 0: - logger.info("Writing example %d of %d" % (ex_index, len(examples))) - choices_inputs = [] - for _ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)): - text_a = context - if example.question.find("_") != -1: - # this is for cloze question - text_b = example.question.replace("_", ending) - else: - text_b = example.question + " " + ending - - inputs = tokenizer.encode_plus( - text_a, - text_b, - add_special_tokens=True, - max_length=max_length, - pad_to_max_length=True, - return_overflowing_tokens=True, - ) - 
if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0: - logger.info( - "Attention! you are cropping tokens (swag task is ok). " - "If you are training ARC and RACE and you are poping question + options," - "you need to try to use a bigger max seq length!" - ) - - choices_inputs.append(inputs) - - label = label_map[example.label] - - input_ids = [x["input_ids"] for x in choices_inputs] - attention_mask = ( - [x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None - ) - token_type_ids = ( - [x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None - ) - - features.append( - InputFeatures( - example_id=example.example_id, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - label=label, - ) - ) - - for f in features[:2]: - logger.info("*** Example ***") - logger.info("feature: %s" % f) - - return features diff --git a/orttraining/pytorch_frontend_examples/mnist_training.py b/orttraining/pytorch_frontend_examples/mnist_training.py deleted file mode 100644 index dc9b3f654400c..0000000000000 --- a/orttraining/pytorch_frontend_examples/mnist_training.py +++ /dev/null @@ -1,200 +0,0 @@ -## This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py -## with modification to do training using onnxruntime as backend on cuda device. -## A private PyTorch build from https://aiinfra.visualstudio.com/Lotus/_git/pytorch (ORTTraining branch) is needed to run the demo. - -## Model testing is not complete. 
- -import argparse -import os - -import numpy as np # noqa: F401 -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim # noqa: F401 -from mpi4py import MPI -from torchvision import datasets, transforms - -from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer - -try: # noqa: SIM105 - from onnxruntime.capi._pybind_state import set_cuda_device_id -except ImportError: - pass - - -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return out - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -def train_with_trainer(args, trainer, device, train_loader, epoch): - for batch_idx, (data, target) in enumerate(train_loader): - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - learning_rate = torch.tensor([args.lr]) - loss = trainer.train_step(data, target, learning_rate) - - # Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss. - if batch_idx % args.log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss[0], - ) - ) - - -# TODO: comple this once ORT training can do evaluation. 
-def test_with_trainer(args, trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = F.log_softmax(trainer.eval_step(data, fetches=["probability"]), dim=1) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def mnist_model_description(): - input_desc = IODescription("input1", ["batch", 784], torch.float32) - label_desc = IODescription( - "label", - [ - "batch", - ], - torch.int64, - num_classes=10, - ) - loss_desc = IODescription("loss", [], torch.float32) - probability_desc = IODescription("probability", ["batch", 10], torch.float32) - return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc]) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, 
metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - - args = parser.parse_args() - use_cuda = not args.no_cuda and torch.cuda.is_available() - - torch.manual_seed(args.seed) - - kwargs = {"num_workers": 0, "pin_memory": True} - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "../data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - **kwargs, - ) - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "../data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - **kwargs, - ) - - comm = MPI.COMM_WORLD - args.local_rank = ( - int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) if ("OMPI_COMM_WORLD_LOCAL_RANK" in os.environ) else 0 - ) - args.world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) if ("OMPI_COMM_WORLD_RANK" in os.environ) else 0 - args.world_size = comm.Get_size() - if use_cuda: - torch.cuda.set_device(args.local_rank) - device = torch.device("cuda", args.local_rank) - args.n_gpu = 1 - set_cuda_device_id(args.local_rank) - else: - device = torch.device("cpu") - - input_size = 784 - hidden_size = 500 - num_classes = 10 - model = NeuralNet(input_size, hidden_size, num_classes) - - model_desc = mnist_model_description() - # use log_interval as gradient accumulate steps - trainer = ORTTrainer( - model, - my_loss, - model_desc, - "SGDOptimizer", - None, - IODescription( - "Learning_Rate", - [ - 1, - ], - torch.float32, - ), - device, - 1, - args.world_rank, - args.world_size, - use_mixed_precision=False, - allreduce_post_accumulation=True, - ) - print("\nBuild ort model done.") - - for epoch in range(1, args.epochs + 1): - 
train_with_trainer(args, trainer, device, train_loader, epoch) - test_with_trainer(args, trainer, device, test_loader) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/mnist/mnist_original.onnx b/samples/python/training/orttrainer/mnist/mnist_original.onnx deleted file mode 100644 index 15931affb5ccf..0000000000000 Binary files a/samples/python/training/orttrainer/mnist/mnist_original.onnx and /dev/null differ diff --git a/samples/python/training/orttrainer/mnist/ort_mnist.py b/samples/python/training/orttrainer/mnist/ort_mnist.py deleted file mode 100644 index 8f8ccf373ccf6..0000000000000 --- a/samples/python/training/orttrainer/mnist/ort_mnist.py +++ /dev/null @@ -1,174 +0,0 @@ -# This code is from https://github.com/pytorch/examples/blob/master/mnist/main.py -# with modification to do training using onnxruntime as backend on cuda device. - -import argparse -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision import datasets, transforms - -import onnxruntime -from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim - - -# Pytorch model -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, input1): - out = self.fc1(input1) - out = self.relu(out) - out = self.fc2(out) - return out - - -# ONNX Runtime training -def mnist_model_description(): - return { - "inputs": [("input1", ["batch", 784]), ("label", ["batch"])], - "outputs": [("loss", [], True), ("probability", ["batch", 10])], - } - - -def my_loss(x, target): - return F.nll_loss(F.log_softmax(x, dim=1), target) - - -# Helpers -def train(log_interval, trainer, device, train_loader, epoch, train_steps): - for batch_idx, (data, target) in enumerate(train_loader): - if batch_idx == train_steps: - break - - # Fetch data - 
data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - # Train step - loss, prob = trainer.train_step(data, target) - - # Stats - if batch_idx % log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, batch_idx * len(data), len(train_loader.dataset), 100.0 * batch_idx / len(train_loader), loss - ) - ) - - -def test(trainer, device, test_loader): - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - - # Using fetches around without eval_step to not pass 'target' as input - trainer._train_step_info.fetches = ["probability"] - output = F.log_softmax(trainer.eval_step(data), dim=1) - trainer._train_step_info.fetches = [] - - # Stats - test_loss += F.nll_loss(output, target, reduction="sum").item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="ONNX Runtime MNIST Example") - parser.add_argument( - "--train-steps", - type=int, - default=-1, - metavar="N", - help="number of steps to train. 
Set -1 to run through whole dataset (default: -1)", - ) - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model state") - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - onnxruntime.set_seed(args.seed) - - # Data loader - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - ) - - if args.test_batch_size > 0: - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - ) - - # Modeling - model = NeuralNet(784, 500, 10) - model_desc = mnist_model_description() - optim_config = optim.SGDConfig(lr=args.lr) - opts = {"device": {"id": device}} - opts = ORTTrainerOptions(opts) - - trainer 
= ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) - - # Train loop - for epoch in range(1, args.epochs + 1): - train(args.log_interval, trainer, device, train_loader, epoch, args.train_steps) - if args.test_batch_size > 0: - test(trainer, device, test_loader) - - # Save model - if args.save_path: - torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt")) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/mnist/pytorch_mnist.py b/samples/python/training/orttrainer/mnist/pytorch_mnist.py deleted file mode 100644 index 2e451d85f62e8..0000000000000 --- a/samples/python/training/orttrainer/mnist/pytorch_mnist.py +++ /dev/null @@ -1,157 +0,0 @@ -import argparse -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms - - -# Pytorch model -class NeuralNet(nn.Module): - def __init__(self, input_size, hidden_size, num_classes): - super().__init__() - self.fc1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, num_classes) - - def forward(self, input1): - out = self.fc1(input1) - out = self.relu(out) - out = self.fc2(out) - return out - - -def my_loss(x, target, is_train=True): - if is_train: - return F.nll_loss(F.log_softmax(x, dim=1), target) - else: - return F.nll_loss(F.log_softmax(x, dim=1), target, reduction="sum") - - -# Helpers -def train(args, model, device, train_loader, optimizer, epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - if batch_idx == args.train_steps: - break - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - optimizer.zero_grad() - output = model(data) - loss = my_loss(output, target) - loss.backward() - optimizer.step() - if batch_idx % args.log_interval == 0: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: 
{:.6f}".format( - epoch, - batch_idx * len(data), - len(train_loader.dataset), - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - - -def test(model, device, test_loader): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - data, target = data.to(device), target.to(device) # noqa: PLW2901 - data = data.reshape(data.shape[0], -1) # noqa: PLW2901 - output = model(data) - # Stats - test_loss += my_loss(output, target, False).item() - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq(target.view_as(pred)).sum().item() - - test_loss /= len(test_loader.dataset) - - print( - "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( - test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) - ) - ) - - -def main(): - # Training settings - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--train-steps", - type=int, - default=-1, - metavar="N", - help="number of steps to train. 
Set -1 to run through whole dataset (default: -1)", - ) - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" - ) - parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)") - parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model") - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - - # Data loader - train_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=True, - download=True, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.batch_size, - shuffle=True, - ) - - if args.test_batch_size > 0: - test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "./data", - train=False, - transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), - ), - batch_size=args.test_batch_size, - shuffle=True, - ) - - # Modeling - model = NeuralNet(784, 500, 10).to(device) - optimizer = optim.SGD(model.parameters(), lr=args.lr) - - # Train loop - for epoch in range(1, args.epochs + 1): - train(args, model, device, train_loader, optimizer, epoch) - if 
args.test_batch_size > 0: - test(model, device, test_loader) - - # Save model - if args.save_path: - torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt")) - - -if __name__ == "__main__": - main() diff --git a/samples/python/training/orttrainer/pytorch_transformer/README.md b/samples/python/training/orttrainer/pytorch_transformer/README.md deleted file mode 100644 index cda8cba6ca0ad..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# TransformerModel example - -This example was adapted from Pytorch's [Sequence-to-Sequence Modeling with nn.Transformer and TorchText](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) tutorial - -## Requirements - -* PyTorch 1.6+ -* TorchText 0.6+ -* ONNX Runtime 1.5+ - -## Running PyTorch version - -```bash -python pt_train.py -``` - -## Running ONNX Runtime version - -```bash -python ort_train.py -``` - -## Optional arguments - -| Argument | Description | Default | -| :---------------- | :-----------------------------------------------------: | --------: | -| --batch-size | input batch size for training | 20 | -| --test-batch-size | input batch size for testing | 20 | -| --epochs | number of epochs to train | 2 | -| --lr | learning rate | 0.001 | -| --no-cuda | disables CUDA training | False | -| --seed | random seed | 1 | -| --log-interval | how many batches to wait before logging training status | 200 | diff --git a/samples/python/training/orttrainer/pytorch_transformer/ort_train.py b/samples/python/training/orttrainer/pytorch_transformer/ort_train.py deleted file mode 100644 index 551e878cc9035..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/ort_train.py +++ /dev/null @@ -1,89 +0,0 @@ -import argparse - -import torch -from ort_utils import my_loss, transformer_model_description_dynamic_axes -from pt_model import TransformerModel -from utils import get_batch, prepare_data - -import onnxruntime - - 
-def train(trainer, data_source, device, epoch, args, bptt=35): - total_loss = 0.0 - for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): - data, targets = get_batch(data_source, i) - - loss, pred = trainer.train_step(data, targets) - total_loss += loss.item() - if batch % args.log_interval == 0 and batch > 0: - cur_loss = total_loss / args.log_interval - print( - "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format( - epoch, batch, len(data_source) // bptt, cur_loss - ) - ) - total_loss = 0 - - -def evaluate(trainer, data_source, bptt=35): - total_loss = 0.0 - with torch.no_grad(): - for i in range(0, data_source.size(0) - 1, bptt): - data, targets = get_batch(data_source, i) - loss, pred = trainer.eval_step(data, targets) - total_loss += len(data) * loss.item() - return total_loss / (len(data_source) - 1) - - -if __name__ == "__main__": - # Training settings - parser = argparse.ArgumentParser(description="PyTorch TransformerModel example") - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)" - ) - parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)") - parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=200, - metavar="N", - help="how many batches to wait before logging training status (default: 200)", - ) - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - 
torch.manual_seed(args.seed) - onnxruntime.set_seed(args.seed) - - # Model - optim_config = onnxruntime.training.optim.SGDConfig(lr=args.lr) - model_desc = transformer_model_description_dynamic_axes() - model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - - # Preparing data - train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size) - trainer = onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss) - - # Train - for epoch in range(1, args.epochs + 1): - train(trainer, train_data, device, epoch, args) - val_loss = evaluate(trainer, val_data) - print("-" * 89) - print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ") - print("-" * 89) - - # Evaluate - test_loss = evaluate(trainer, test_data) - print("=" * 89) - print(f"| End of training | test loss {test_loss:5.2f}") - print("=" * 89) diff --git a/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py b/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py deleted file mode 100644 index 73992f5596f5f..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription -from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription - - -def my_loss(x, target): - x = x.view(-1, 28785) - return torch.nn.CrossEntropyLoss()(x, target) - - -def transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - model_desc = { - "inputs": [("input1", [bptt, batch_size]), ("label", [bptt * batch_size])], - "outputs": [("loss", [], True), ("predictions", [bptt, batch_size, ntokens])], - } - return model_desc - - -def transformer_model_description_dynamic_axes(ntokens=28785): - model_desc = { - "inputs": [("input1", ["bptt", "batch_size"]), ("label", ["bptt_x_batch_size"])], - "outputs": [("loss", [], True), ("predictions", ["bptt", 
"batch_size", ntokens])], - } - return model_desc - - -def legacy_transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - input_desc = Legacy_IODescription("input1", [bptt, batch_size]) - label_desc = Legacy_IODescription("label", [bptt * batch_size]) - loss_desc = Legacy_IODescription("loss", []) - predictions_desc = Legacy_IODescription("predictions", [bptt, batch_size, ntokens]) - return ( - Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), - Legacy_IODescription("__learning_rate", [1]), - ) - - -def legacy_transformer_model_description_dynamic_axes(ntokens=28785): - input_desc = Legacy_IODescription("input1", ["bptt", "batch_size"]) - label_desc = Legacy_IODescription("label", ["bptt_x_batch_size"]) - loss_desc = Legacy_IODescription("loss", []) - predictions_desc = Legacy_IODescription("predictions", ["bptt", "batch_size", ntokens]) - return ( - Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), - Legacy_IODescription("__learning_rate", [1]), - ) diff --git a/samples/python/training/orttrainer/pytorch_transformer/pt_model.py b/samples/python/training/orttrainer/pytorch_transformer/pt_model.py deleted file mode 100644 index 4f2e03192c6cf..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/pt_model.py +++ /dev/null @@ -1,62 +0,0 @@ -import math - -import torch -import torch.nn as nn - - -class TransformerModel(nn.Module): - def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): - super().__init__() - from torch.nn import TransformerEncoder, TransformerEncoderLayer - - self.model_type = "Transformer" - self.input1_mask = None - self.pos_encoder = PositionalEncoding(ninp, dropout) - encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) - self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) - self.encoder = nn.Embedding(ntoken, ninp) - self.ninp = ninp - self.decoder = nn.Linear(ninp, ntoken) - - self.init_weights() - 
- def _generate_square_subsequent_mask(self, sz): - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0) - return mask - - def init_weights(self): - initrange = 0.1 - self.encoder.weight.data.uniform_(-initrange, initrange) - self.decoder.bias.data.zero_() - self.decoder.weight.data.uniform_(-initrange, initrange) - - def forward(self, input1): - if self.input1_mask is None or self.input1_mask.size(0) != input1.size(0): - device = input1.device - mask = self._generate_square_subsequent_mask(input1.size(0)).to(device) - self.input1_mask = mask - - input1 = self.encoder(input1) * math.sqrt(self.ninp) - input1 = self.pos_encoder(input1) - output = self.transformer_encoder(input1, self.input1_mask) - output = self.decoder(output) - return output - - -class PositionalEncoding(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer("pe", pe) - - def forward(self, x): - x = x + self.pe[: x.size(0), :] - return self.dropout(x) diff --git a/samples/python/training/orttrainer/pytorch_transformer/pt_train.py b/samples/python/training/orttrainer/pytorch_transformer/pt_train.py deleted file mode 100644 index a197fb50357e9..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/pt_train.py +++ /dev/null @@ -1,94 +0,0 @@ -import argparse - -import torch -import torch.nn as nn -from pt_model import TransformerModel -from utils import get_batch, prepare_data - - -def train(model, data_source, device, epoch, args, bptt=35): - total_loss = 0.0 - 
model.train() - for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): - data, targets = get_batch(data_source, i) - - optimizer.zero_grad() - output = model(data) - loss = criterion(output.view(-1, 28785), targets) - loss.backward() - optimizer.step() - - total_loss += loss.item() - if batch % args.log_interval == 0 and batch > 0: - cur_loss = total_loss / args.log_interval - print( - "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format( - epoch, batch, len(data_source) // bptt, cur_loss - ) - ) - total_loss = 0 - - -def evaluate(model, data_source, criterion, bptt=35): - total_loss = 0.0 - model.eval() - with torch.no_grad(): - for i in range(0, data_source.size(0) - 1, bptt): - data, targets = get_batch(data_source, i) - output = model(data) - output_flat = output.view(-1, 28785) - total_loss += len(data) * criterion(output_flat, targets).item() - return total_loss / (len(data_source) - 1) - - -if __name__ == "__main__": - # Training settings - parser = argparse.ArgumentParser(description="PyTorch TransformerModel example") - parser.add_argument( - "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" - ) - parser.add_argument( - "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)" - ) - parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)") - parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)") - parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=200, - metavar="N", - help="how many batches to wait before logging training status (default: 200)", - ) - - # Basic setup - args = parser.parse_args() - if not args.no_cuda and 
torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - torch.manual_seed(args.seed) - - # Model - criterion = nn.CrossEntropyLoss() - lr = 0.001 - model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) - optimizer = torch.optim.SGD(model.parameters(), lr=lr) - - # Preparing data - train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size) - - # Train - for epoch in range(1, args.epochs + 1): - train(model, train_data, device, epoch, args) - val_loss = evaluate(model, val_data, criterion) - print("-" * 89) - print(f"| end of epoch {epoch:3d} | valid loss {val_loss:5.2f} | ") - print("-" * 89) - - # Evaluate - test_loss = evaluate(model, test_data, criterion) - print("=" * 89) - print(f"| End of training | test loss {test_loss:5.2f}") - print("=" * 89) diff --git a/samples/python/training/orttrainer/pytorch_transformer/utils.py b/samples/python/training/orttrainer/pytorch_transformer/utils.py deleted file mode 100644 index 3be8b6cf3f420..0000000000000 --- a/samples/python/training/orttrainer/pytorch_transformer/utils.py +++ /dev/null @@ -1,59 +0,0 @@ -import os - -import torch -from torchtext.data.utils import get_tokenizer -from torchtext.utils import download_from_url, extract_archive -from torchtext.vocab import build_vocab_from_iterator - - -def batchify(data, bsz, device): - # Divide the dataset into bsz parts. - nbatch = data.size(0) // bsz - # Trim off any extra elements that wouldn't cleanly fit (remainders). - data = data.narrow(0, 0, nbatch * bsz) - # Evenly divide the data across the bsz batches. 
- data = data.view(bsz, -1).t().contiguous() - return data.to(device) - - -def get_batch(source, i, bptt=35): - seq_len = min(bptt, len(source) - 1 - i) - data = source[i : i + seq_len] - target = source[i + 1 : i + 1 + seq_len].view(-1) - return data, target - - -def prepare_data(device="cpu", train_batch_size=20, eval_batch_size=20, data_dir=None): - url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip" - - download_path = ".data_wikitext_2_v1" - extract_path = None - if data_dir: - download_path = os.path.join(data_dir, "download") - os.makedirs(download_path, exist_ok=True) - download_path = os.path.join(download_path, "wikitext-2-v1.zip") - - extract_path = os.path.join(data_dir, "extracted") - os.makedirs(extract_path, exist_ok=True) - - test_filepath, valid_filepath, train_filepath = extract_archive( - download_from_url(url, root=download_path), to_path=extract_path - ) - tokenizer = get_tokenizer("basic_english") - vocab = build_vocab_from_iterator(map(tokenizer, iter(open(train_filepath, encoding="utf8")))) # noqa: SIM115 - - def data_process(raw_text_iter): - data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter] - return torch.cat(tuple(filter(lambda t: t.numel() > 0, data))) - - train_data = data_process(iter(open(train_filepath, encoding="utf8"))) # noqa: SIM115 - val_data = data_process(iter(open(valid_filepath, encoding="utf8"))) # noqa: SIM115 - test_data = data_process(iter(open(test_filepath, encoding="utf8"))) # noqa: SIM115 - - device = torch.device(device) - - train_data = batchify(train_data, train_batch_size, device) - val_data = batchify(val_data, eval_batch_size, device) - test_data = batchify(test_data, eval_batch_size, device) - - return train_data, val_data, test_data diff --git a/setup.py b/setup.py index 1c04433c9a7ca..da4943c4ef7ae 100644 --- a/setup.py +++ b/setup.py @@ -398,7 +398,6 @@ def finalize_options(self): "onnxruntime", "onnxruntime.backend", 
"onnxruntime.capi", - "onnxruntime.capi.training", "onnxruntime.datasets", "onnxruntime.tools", "onnxruntime.tools.mobile_helpers", diff --git a/tools/android_custom_build/Dockerfile b/tools/android_custom_build/Dockerfile index 66b6a36e5a8c0..754a6633b0c62 100644 --- a/tools/android_custom_build/Dockerfile +++ b/tools/android_custom_build/Dockerfile @@ -55,7 +55,7 @@ WORKDIR /workspace # install Android SDK and tools ENV ANDROID_HOME=~/android-sdk -ENV NDK_VERSION=26.0.10792818 +ENV NDK_VERSION=26.1.10909125 ENV ANDROID_NDK_HOME=${ANDROID_HOME}/ndk/${NDK_VERSION} RUN aria2c -q -d /tmp -o cmdline-tools.zip \ diff --git a/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml b/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml index 8cc7f63a193cc..b8dba89b0b899 100644 --- a/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/use-android-ndk.yml @@ -3,7 +3,7 @@ parameters: - name: AndroidNdkVersion type: string - default: "26.0.10792818" # LTS version + default: "26.1.10909125" # LTS version steps: - bash: | diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index c649883ea0d8b..9982b36509b68 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -65,7 +65,6 @@ stages: clean: all steps: - checkout: self - fetchDepth: 1 submodules: false - script: | git submodule sync -- cmake/external/onnx